# NOTE: non-code navigation chrome from the code-viewer capture removed.
1"""Repository rule for CUDA autoconfiguration.
2
3`cuda_configure` depends on the following environment variables:
4
5  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
6  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
7  * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
8  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
9    both host and device code compilation if TF_CUDA_CLANG is 1.
10  * `TF_SYSROOT`: The sysroot to use when compiling.
11  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
12    compiler and use it to build tensorflow. When this option is set
13    CLANG_CUDA_COMPILER_PATH is ignored.
14  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
15    `/usr/local/cuda,usr/`.
16  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
17    `/usr/local/cuda`.
18  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
19    use the system default.
20  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
21  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
22    `/usr/local/cuda`.
23  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
24    `3.5,5.2`.
25  * `PYTHON_BIN_PATH`: The python binary path
26"""
27
28load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
29load(
30    "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
31    "escape_string",
32    "get_env_var",
33)
34load(
35    "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
36    "find_msvc_tool",
37    "find_vc_path",
38    "setup_vc_env_vars",
39)
40load(
41    "//third_party/remote_config:common.bzl",
42    "config_repo_label",
43    "err_out",
44    "execute",
45    "get_bash_bin",
46    "get_cpu_value",
47    "get_host_environ",
48    "get_python_bin",
49    "is_windows",
50    "raw_exec",
51    "read_dir",
52    "realpath",
53    "which",
54)
55
# Names of the environment variables this rule reads. See the module
# docstring above for the semantics of each variable.
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
_TF_SYSROOT = "TF_SYSROOT"
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_TF_CUDA_VERSION = "TF_CUDA_VERSION"
_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
68
def to_list_of_strings(elements):
    """Render a list of strings as the body of a Starlark list literal.

    Each element is wrapped in double quotes and the results are joined with
    ", ", so ["a", "b", "c"] becomes '"a", "b", "c"'. Intended for splicing
    string lists into .bzl file templates.

    Args:
      elements: list of string elements

    Returns:
      A single string of quoted elements separated by commas.
    """
    return ", ".join(["\"%s\"" % item for item in elements])
82
def verify_build_defines(params):
    """Verify all variables that crosstool/BUILD.tpl expects are substituted.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template.
    """
    required = [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "host_compiler_warnings",
        "linker_bin_path",
        "compiler_deps",
        "msvc_cl_path",
        "msvc_env_include",
        "msvc_env_lib",
        "msvc_env_path",
        "msvc_env_tmp",
        "msvc_lib_path",
        "msvc_link_path",
        "msvc_ml_path",
        "unfiltered_compile_flags",
        "win_compiler_deps",
    ]

    # Template keys appear in params in the "%{name}" placeholder form.
    missing = [name for name in required if ("%{" + name + "}") not in params]
    if missing:
        auto_configure_fail(
            "BUILD.tpl template is missing these variables: %s.\nWe only got: %s." % (
                str(missing),
                str(params),
            ),
        )
120
def _get_nvcc_tmp_dir_for_windows(repository_ctx):
    """Return the Windows tmp directory for nvcc to generate intermediate source files."""
    # Double every backslash so the path survives template substitution.
    tmp_dir = get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp")
    escaped = escape_string(tmp_dir.replace("\\", "\\\\"))
    return escaped + "\\\\nvcc_inter_files_tmp_dir"
130
def _get_msvc_compiler(repository_ctx):
    """Return the path to MSVC's cl.exe, with forward-slash separators."""
    tool = find_msvc_tool(repository_ctx, find_vc_path(repository_ctx), "cl.exe")
    return tool.replace("\\", "/")
134
def _get_win_cuda_defines(repository_ctx):
    """Return CROSSTOOL defines for Windows"""

    # If we are not on Windows, return fake values for Windows specific fields.
    # This ensures the CROSSTOOL file parser is happy.
    if not is_windows(repository_ctx):
        return {
            "%{msvc_env_tmp}": "msvc_not_used",
            "%{msvc_env_path}": "msvc_not_used",
            "%{msvc_env_include}": "msvc_not_used",
            "%{msvc_env_lib}": "msvc_not_used",
            "%{msvc_cl_path}": "msvc_not_used",
            "%{msvc_ml_path}": "msvc_not_used",
            "%{msvc_link_path}": "msvc_not_used",
            "%{msvc_lib_path}": "msvc_not_used",
        }

    # MSVC is mandatory for a Windows CUDA build; hard-fail if absent.
    vc_path = find_vc_path(repository_ctx)
    if not vc_path:
        auto_configure_fail(
            "Visual C++ build tools not found on your machine." +
            "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
        )
        return {}

    # Capture the PATH/INCLUDE/LIB environment that vcvars would set, escaped
    # for template substitution.
    env = setup_vc_env_vars(repository_ctx, vc_path)
    escaped_paths = escape_string(env["PATH"])
    escaped_include_paths = escape_string(env["INCLUDE"])
    escaped_lib_paths = escape_string(env["LIB"])
    escaped_tmp_dir = escape_string(
        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
            "\\",
            "\\\\",
        ),
    )

    # NOTE(review): the "cl" slot is filled with the python binary rather than
    # cl.exe — presumably the crosstool template drives a python wrapper
    # script around nvcc/cl here. Confirm against crosstool/BUILD.tpl.
    msvc_cl_path = get_python_bin(repository_ctx)
    msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
        "\\",
        "/",
    )
    msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
        "\\",
        "/",
    )
    msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
        "\\",
        "/",
    )

    # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
    # The generated files are guaranteed to have unique name, so they can share
    # the same tmp directory
    escaped_cxx_include_directories = [
        _get_nvcc_tmp_dir_for_windows(repository_ctx),
        "C:\\\\botcode\\\\w",
    ]
    for path in escaped_include_paths.split(";"):
        if path:
            escaped_cxx_include_directories.append(path)

    return {
        "%{msvc_env_tmp}": escaped_tmp_dir,
        "%{msvc_env_path}": escaped_paths,
        "%{msvc_env_include}": escaped_include_paths,
        "%{msvc_env_lib}": escaped_lib_paths,
        "%{msvc_cl_path}": msvc_cl_path,
        "%{msvc_ml_path}": msvc_ml_path,
        "%{msvc_link_path}": msvc_link_path,
        "%{msvc_lib_path}": msvc_lib_path,
        "%{cxx_builtin_include_directories}": to_list_of_strings(
            escaped_cxx_include_directories,
        ),
    }
209
210# TODO(dzc): Once these functions have been factored out of Bazel's
211# cc_configure.bzl, load them from @bazel_tools instead.
212# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
    """Find the C++ compiler.

    On Windows this resolves MSVC's cl.exe. Otherwise the compiler is clang
    (when _use_cuda_clang is true) or gcc, and the path can be overridden via
    the CLANG_CUDA_COMPILER_PATH / GCC_HOST_COMPILER_PATH environment
    variables. Fails the configuration if no compiler can be located.
    """
    if is_windows(repository_ctx):
        return _get_msvc_compiler(repository_ctx)

    if _use_cuda_clang(repository_ctx):
        target_cc_name = "clang"
        cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
        if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
            # Fixed location where the downloaded clang is unpacked —
            # presumably by download_clang; confirm against that rule.
            return "extra_tools/bin/clang"
    else:
        target_cc_name = "gcc"
        cc_path_envvar = _GCC_HOST_COMPILER_PATH
    cc_name = target_cc_name

    # The env override may be a bare name (resolved via `which` below) or an
    # absolute path (returned as-is).
    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
    if cc_name_from_env:
        cc_name = cc_name_from_env
    if cc_name.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return cc_name
    cc = which(repository_ctx, cc_name)
    if cc == None:
        fail(("Cannot find {}, either correct your path or set the {}" +
              " environment variable").format(target_cc_name, cc_path_envvar))
    return cc
239
# Substring of the stderr line that `cc -E -v` prints immediately before its
# list of default include directories ("#include <...> search starts here:").
_INC_DIR_MARKER_BEGIN = "#include <...>"

# OSX add " (framework directory)" at the end of line, strip it.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
245
def _cxx_inc_convert(path):
    """Convert path returned by cc -E xc++ in a complete path."""
    cleaned = path.strip()

    # On macOS the compiler appends " (framework directory)" — drop it.
    if cleaned.endswith(_OSX_FRAMEWORK_SUFFIX):
        cleaned = cleaned[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
    return cleaned
252
def _normalize_include_path(repository_ctx, path):
    """Normalizes include paths before writing them to the crosstool.

      If path points inside the 'crosstool' folder of the repository, a relative
      path is returned.
      If path points outside the 'crosstool' folder, an absolute path is returned.
      """
    resolved = str(repository_ctx.path(path))
    crosstool_dir = str(repository_ctx.path(".").get_child("crosstool"))

    if not resolved.startswith(crosstool_dir):
        return resolved

    # Strip "$REPO/crosstool" plus the trailing path separator.
    return resolved[len(crosstool_dir) + 1:]
267
def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot):
    """Compute the list of default C or C++ include directories.

    Runs `cc -E -x<lang> - -v` and scrapes the include search path that the
    compiler reports on stderr.

    Args:
      repository_ctx: The repository context.
      cc: Path to the compiler binary.
      lang_is_cpp: True to query the C++ search path, False for the C one.
      tf_sysroot: Optional sysroot forwarded to the compiler via --sysroot.

    Returns:
      List of normalized include directories, or [] if the compiler output
      could not be parsed.
    """
    if lang_is_cpp:
        lang = "c++"
    else:
        lang = "c"
    sysroot = []
    if tf_sysroot:
        sysroot += ["--sysroot", tf_sysroot]
    result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] +
                                      sysroot)
    stderr = err_out(result)

    # The search list appears after the _INC_DIR_MARKER_BEGIN line; each
    # directory is printed on its own line, indented with one space.
    index1 = stderr.find(_INC_DIR_MARKER_BEGIN)
    if index1 == -1:
        return []
    index1 = stderr.find("\n", index1)  # End of the marker line itself.
    if index1 == -1:
        return []
    index2 = stderr.rfind("\n ")  # Start of the last space-indented line.
    if index2 == -1 or index2 < index1:
        return []
    index2 = stderr.find("\n", index2 + 1)  # End of that last directory line.
    if index2 == -1:
        inc_dirs = stderr[index1 + 1:]
    else:
        inc_dirs = stderr[index1 + 1:index2].strip()

    return [
        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
        for p in inc_dirs.split("\n")
    ]
299
def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot):
    """Compute the list of default C and C++ include directories."""

    # `clang -xc` sometimes reports include paths that differ from the
    # `clang -xc++` ones (e.g. a symlink vs. its target directory), so query
    # both modes and merge: C++ entries first, then any C-only entries.
    cpp_dirs = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        True,
        tf_sysroot,
    )
    c_dirs = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        False,
        tf_sysroot,
    )
    c_only = [d for d in c_dirs if d not in cpp_dirs]
    return cpp_dirs + c_only
324
def auto_configure_fail(msg):
    """Output failure message when cuda configuration fails."""
    # ANSI escapes paint the error heading red, then reset the color.
    fail("\n%sCuda Configuration Error:%s %s\n" % ("\033[0;31m", "\033[0m", msg))
330
331# END cc_configure common functions (see TODO above).
332
def _cuda_include_path(repository_ctx, cuda_config):
    """Returns the CUDA toolkit include directories.

      Args:
        repository_ctx: The repository context.
        cuda_config: The CUDA config as returned by _get_cuda_config.

      Returns:
        A list of include directories for the CUDA toolkit headers.
      """
    nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
        cuda_config.cuda_toolkit_path,
        ".exe" if cuda_config.cpu_value == "Windows" else "",
    ))

    # The expected exit code of this command is non-zero. Bazel remote execution
    # only caches commands with zero exit code. So force a zero exit code.
    cmd = "%s -v /dev/null -o /dev/null ; [ $? -eq 1 ]" % str(nvcc_path)
    result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd])
    target_dir = ""

    # `nvcc -v` dumps internal variables on stderr as "#$ NAME=VALUE" lines;
    # _TARGET_DIR_ is a toolkit-relative directory whose include/ subfolder
    # holds additional headers.
    for one_line in err_out(result).splitlines():
        if one_line.startswith("#$ _TARGET_DIR_="):
            target_dir = (
                cuda_config.cuda_toolkit_path + "/" + one_line.replace(
                    "#$ _TARGET_DIR_=",
                    "",
                ) + "/include"
            )
    inc_entries = []
    if target_dir != "":
        inc_entries.append(realpath(repository_ctx, target_dir))

    # Always include the toolkit's top-level include directory as well.
    inc_entries.append(realpath(repository_ctx, cuda_config.cuda_toolkit_path + "/include"))
    return inc_entries
366
def enable_cuda(repository_ctx):
    """Returns whether to build with CUDA support."""
    # TF_NEED_CUDA is expected to hold "0"/"1"; a missing variable yields
    # False, which int() maps to 0 (disabled).
    value = get_host_environ(repository_ctx, "TF_NEED_CUDA", False)
    return int(value)
370
def matches_version(environ_version, detected_version):
    """Checks whether the user-specified version matches the detected version.

    The match is a weak prefix match on dot-separated components: if the user
    gives only the major (or major.minor) part, the versions still match as
    long as every specified component agrees with the detected one.

        environ_version  detected_version  result
        -----------------------------------------
        5.1.3            5.1.3             True
        5.1              5.1.3             True
        5                5.1               True
        5.1.3            5.1               False
        5.2.3            5.1.3             False

    Args:
      environ_version: The version specified by the user via environment
        variables.
      detected_version: The version autodetected from the CUDA installation on
        the system.
    Returns: True if user-specified version matches detected version and False
      otherwise.
    """
    wanted = environ_version.split(".")
    found = detected_version.split(".")

    # The user may not specify more components than were detected.
    if len(found) < len(wanted):
        return False

    # Compare component-by-component over the user-specified prefix.
    for want, got in zip(wanted, found):
        if want != got:
            return False
    return True
406
# Prefix of the `nvcc --version` output line carrying the release number.
# NOTE(review): not referenced in this portion of the file — presumably used
# by version-detection code elsewhere; confirm before removing.
_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "

# Line prefix used to locate the cuDNN major version in a header —
# presumably cudnn.h / cudnn_version.h; usage is not visible in this chunk.
_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
410
def compute_capabilities(repository_ctx):
    """Returns a list of strings representing cuda compute capabilities.

    Args:
      repository_ctx: the repo rule's context.
    Returns: list of cuda architectures to compile for. 'compute_xy' refers to
      both PTX and SASS, 'sm_xy' refers to SASS only.
    """
    raw = get_host_environ(
        repository_ctx,
        _TF_CUDA_COMPUTE_CAPABILITIES,
        "compute_35,compute_52",
    )
    capabilities = raw.split(",")

    # Translate legacy 'x.y' entries into the 'compute_xy' form.
    for idx in range(len(capabilities)):
        pieces = capabilities[idx].split(".")
        if len(pieces) == 2:
            capabilities[idx] = "compute_%s%s" % (pieces[0], pieces[1])

    # Deduplicate while preserving first-seen order.
    capabilities = dict(zip(capabilities, capabilities)).keys()

    # Each entry must be 'compute_XY' or 'sm_XY' with exactly two digits;
    # anything else aborts the configuration.
    for capability in capabilities:
        valid = False
        for prefix in ["compute_", "sm_"]:
            if capability.startswith(prefix):
                suffix = capability[len(prefix):]
                if len(suffix) == 2 and suffix.isdigit():
                    valid = True
        if not valid:
            auto_configure_fail("Invalid compute capability: %s" % capability)

    return capabilities
447
def lib_name(base_name, cpu_value, version = None, static = False):
    """Constructs the platform-specific name of a library.

      Args:
        base_name: The name of the library, such as "cudart"
        cpu_value: The name of the host operating system.
        version: The version of the library.
        static: True the library is static or False if it is a shared object.

      Returns:
        The platform-specific name of the library.
      """
    # A falsy version means "unversioned"; otherwise it becomes a dot suffix.
    suffix = "." + version if version else ""
    if cpu_value == "Windows":
        return "%s.lib" % base_name
    if cpu_value in ("Linux", "FreeBSD"):
        if static:
            return "lib%s.a" % base_name
        return "lib%s.so%s" % (base_name, suffix)
    if cpu_value == "Darwin":
        if static:
            return "lib%s.a" % base_name
        return "lib%s%s.dylib" % (base_name, suffix)
    auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
473
def _lib_path(lib, cpu_value, basedir, version, static):
    """Join basedir with the platform-specific file name of the library."""
    return "%s/%s" % (basedir, lib_name(lib, cpu_value, version, static))
477
def _should_check_soname(version, static):
    """Soname verification applies only to versioned, non-static libraries."""
    return version and not static
480
def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False):
    """Build the (path, check_soname) tuple for one CUDA library."""
    path = _lib_path(lib, cpu_value, basedir, version, static)
    return (path, _should_check_soname(version, static))
486
def _check_cuda_libs(repository_ctx, script_path, libs):
    """Verifies that the given CUDA library paths exist on the system.

    Args:
      repository_ctx: The repository context.
      script_path: Path to the checker script (check_cuda_libs.py).
      libs: List of (library path, check_soname) tuples as produced by
        _check_cuda_lib_params.

    Fails the configuration if the script does not confirm every path.
    """
    python_bin = get_python_bin(repository_ctx)
    contents = repository_ctx.read(script_path).splitlines()

    # Rewrite the script to a local 'script.py' from inside a `python -c`
    # one-liner, then run it. This embeds the script contents in the command
    # itself, so it also works when execute() cannot read source-tree files
    # (e.g. under remote execution).
    cmd = "from os import linesep;"
    cmd += "f = open('script.py', 'w');"
    for line in contents:
        cmd += "f.write('%s' + linesep);" % line
    cmd += "f.close();"
    cmd += "from os import system;"
    args = " ".join(["\"" + path + "\" " + str(check) for path, check in libs])
    cmd += "system('%s script.py %s');" % (python_bin, args)

    all_paths = [path for path, _ in libs]
    checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines()

    # Filter out empty lines from splitting on '\r\n' on Windows
    checked_paths = [path for path in checked_paths if len(path) > 0]

    # The script is expected to echo each path it verified; any difference
    # means at least one library failed the check.
    if all_paths != checked_paths:
        auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths))
507
def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
    """Returns the CUDA and cuDNN libraries on the system.

      Also, verifies that the libraries actually exist.

      Args:
        repository_ctx: The repository context.
        check_cuda_libs_script: The path to a script verifying that the cuda
          libraries exist on the system.
        cuda_config: The CUDA config as returned by _get_cuda_config

      Returns:
        Map of library names to structs of filename and path.
      """
    cpu_value = cuda_config.cpu_value

    # On non-Windows, the driver library (libcuda) is taken from the toolkit's
    # "stubs" subdirectory — presumably because the real driver library need
    # not be installed on the build machine; confirm against toolkit layout.
    stub_dir = "" if is_windows(repository_ctx) else "/stubs"

    # One (path, check_soname) entry per library. The driver stub is looked
    # up unversioned; cudart appears both shared and static.
    check_cuda_libs_params = {
        "cuda": _check_cuda_lib_params(
            "cuda",
            cpu_value,
            cuda_config.config["cuda_library_dir"] + stub_dir,
            version = None,
            static = False,
        ),
        "cudart": _check_cuda_lib_params(
            "cudart",
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cudart_version,
            static = False,
        ),
        "cudart_static": _check_cuda_lib_params(
            "cudart_static",
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cudart_version,
            static = True,
        ),
        "cublas": _check_cuda_lib_params(
            "cublas",
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cublas_version,
            static = False,
        ),
        "cublasLt": _check_cuda_lib_params(
            "cublasLt",
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cublas_version,
            static = False,
        ),
        "cusolver": _check_cuda_lib_params(
            "cusolver",
            cpu_value,
            cuda_config.config["cusolver_library_dir"],
            cuda_config.cusolver_version,
            static = False,
        ),
        "curand": _check_cuda_lib_params(
            "curand",
            cpu_value,
            cuda_config.config["curand_library_dir"],
            cuda_config.curand_version,
            static = False,
        ),
        "cufft": _check_cuda_lib_params(
            "cufft",
            cpu_value,
            cuda_config.config["cufft_library_dir"],
            cuda_config.cufft_version,
            static = False,
        ),
        "cudnn": _check_cuda_lib_params(
            "cudnn",
            cpu_value,
            cuda_config.config["cudnn_library_dir"],
            cuda_config.cudnn_version,
            static = False,
        ),
        # NOTE(review): cupti is versioned with cuda_version rather than a
        # dedicated cupti version — looks deliberate; confirm.
        "cupti": _check_cuda_lib_params(
            "cupti",
            cpu_value,
            cuda_config.config["cupti_library_dir"],
            cuda_config.cuda_version,
            static = False,
        ),
        "cusparse": _check_cuda_lib_params(
            "cusparse",
            cpu_value,
            cuda_config.config["cusparse_library_dir"],
            cuda_config.cusparse_version,
            static = False,
        ),
    }

    # Verify that the libs actually exist at their locations.
    _check_cuda_libs(repository_ctx, check_cuda_libs_script, check_cuda_libs_params.values())

    # Keep only the resolved path (index 0 of each tuple) per library name.
    paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()}
    return paths
610
def _cudart_static_linkopt(cpu_value):
    """Returns additional platform-specific linkopts for cudart."""
    # macOS has no separate librt, so no extra linkopt is needed there.
    if cpu_value == "Darwin":
        return ""
    return "\"-lrt\","
614
def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
    """Run the (compressed) find_cuda_config script and return the result."""
    python_bin = get_python_bin(repository_ctx)

    # Under remote execution, repository_ctx.execute() cannot read files from
    # the source tree. The usual workaround is to inline the file contents
    # into the command line, but find_cuda_config.py exceeds the 8192
    # character command-length limit of cmd.exe on Windows. The script is
    # therefore stored compressed; the one-liner below decodes it, writes it
    # to a local script.py, and runs it.
    compressed = repository_ctx.read(script_path)
    one_liner = (
        "from zlib import decompress;" +
        "from base64 import b64decode;" +
        "from os import system;" +
        "script = decompress(b64decode('%s'));" % compressed +
        "f = open('script.py', 'wb');" +
        "f.write(script);" +
        "f.close();" +
        "system('\"%s\" script.py %s');" % (python_bin, " ".join(cuda_libraries))
    )

    return execute(repository_ctx, [python_bin, "-c", one_liner])
638
639# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
640# and nccl_configure.bzl.
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
    """Returns CUDA config dictionary from running find_cuda_config.py.

    Args:
      repository_ctx: The repository context.
      script_path: Path to the compressed find_cuda_config.py script.
      cuda_libraries: Library names to detect (e.g. ["cuda", "cudnn"]).

    Returns:
      Dict of "key: value" pairs parsed from the script's stdout.

    Fails the configuration if the script exits with a non-zero code.
    """
    exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
    if exec_result.return_code:
        auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))

    # Parse the "key: value" lines from stdout. Split on the FIRST ": " only:
    # without maxsplit, a value containing another ": " would yield a 3-tuple
    # and make the dict() constructor fail.
    return dict([tuple(x.split(": ", 1)) for x in exec_result.stdout.splitlines()])
649
def _get_cuda_config(repository_ctx, find_cuda_config_script):
    """Detects and returns information about the CUDA installation on the system.

      Args:
        repository_ctx: The repository context.
        find_cuda_config_script: Path to the compressed find_cuda_config.py
          script used for detection.

      Returns:
        A struct containing the following fields:
          cuda_toolkit_path: The CUDA toolkit installation directory.
          cuda_version: The version of CUDA on the system.
          cuda_version_major: The major CUDA version component.
          cudart_version: The CUDA runtime version on the system.
          cublas_version, cusolver_version, curand_version, cufft_version,
            cusparse_version: Versions of the respective toolkit libraries.
          cudnn_version: The version of cuDNN on the system.
          compute_capabilities: A list of the system's CUDA compute capabilities.
          cpu_value: The name of the host operating system.
          config: The raw dict returned by find_cuda_config.
      """
    config = find_cuda_config(repository_ctx, find_cuda_config_script, ["cuda", "cudnn"])
    cpu_value = get_cpu_value(repository_ctx)
    toolkit_path = config["cuda_toolkit_path"]

    # On Windows, versions are rendered with the "64_" prefix used in the
    # library file names (e.g. "64_110"); elsewhere as dotted strings.
    is_windows = cpu_value == "Windows"
    cuda_version = config["cuda_version"].split(".")
    cuda_major = cuda_version[0]
    cuda_minor = cuda_version[1]

    # Note: cuda_version is rebound from the split list to a formatted string.
    cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
    cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]

    if int(cuda_major) >= 11:
        # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatibility.
        if int(cuda_major) == 11:
            cudart_version = "64_110" if is_windows else "11.0"
        else:
            cudart_version = ("64_%s" if is_windows else "%s") % cuda_major
        # From CUDA 11 on, each library reports its own (major) version.
        cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
        cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
        curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
        cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0]
        cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0]
    elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
        # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
        # It changed from 'x.y' to just 'x' in CUDA 10.1.
        cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
        cudart_version = cuda_version
        cublas_version = cuda_lib_version
        cusolver_version = cuda_lib_version
        curand_version = cuda_lib_version
        cufft_version = cuda_lib_version
        cusparse_version = cuda_lib_version
    else:
        # Pre-10.1: every library shares the toolkit's 'x.y' version.
        cudart_version = cuda_version
        cublas_version = cuda_version
        cusolver_version = cuda_version
        curand_version = cuda_version
        cufft_version = cuda_version
        cusparse_version = cuda_version

    return struct(
        cuda_toolkit_path = toolkit_path,
        cuda_version = cuda_version,
        cuda_version_major = cuda_major,
        cudart_version = cudart_version,
        cublas_version = cublas_version,
        cusolver_version = cusolver_version,
        curand_version = curand_version,
        cufft_version = cufft_version,
        cusparse_version = cusparse_version,
        cudnn_version = cudnn_version,
        compute_capabilities = compute_capabilities(repository_ctx),
        cpu_value = cpu_value,
        config = config,
    )
722
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Instantiate a //third_party/gpus template with the given substitutions."""
    # The default output path mirrors the template label, with ':' mapped to
    # a path separator.
    destination = out
    if not destination:
        destination = tpl.replace(":", "/")
    repository_ctx.template(
        destination,
        Label("//third_party/gpus/%s.tpl" % tpl),
        substitutions,
    )
731
def _file(repository_ctx, label):
    # Copy //third_party/gpus/<label>.tpl into the repository verbatim
    # (template expansion with an empty substitution map).
    repository_ctx.template(
        label.replace(":", "/"),
        Label("//third_party/gpus/%s.tpl" % label),
        {},
    )
738
# Starlark written into the repository when CUDA is disabled: any build that
# still requests the CUDA crosstool then fails with a clear error message
# instead of a confusing missing-toolchain error.
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

# BUILD file that triggers the error rule above at load time.
_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""
763
def _create_dummy_repository(repository_ctx):
    """Populates the repository with CUDA stubs for builds without GPU support."""
    cpu_value = get_cpu_value(repository_ctx)

    # Set up BUILD file for cuda/.
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "False",
            "%{cuda_extra_copts}": "[]",
            "%{cuda_gpu_architectures}": "[]",
        },
    )

    # Substitutions for the cuda/BUILD template. The regular runtime libraries
    # all follow the "%{<name>_lib}" pattern and are filled in by the loop
    # below; the remaining placeholders are special-cased here.
    build_substitutions = {
        "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
        "%{cudart_static_lib}": lib_name(
            "cudart_static",
            cpu_value,
            static = True,
        ),
        "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
        "%{cub_actual}": ":cuda_headers",
        "%{copy_rules}": """
filegroup(name="cuda-include")
filegroup(name="cublas-include")
filegroup(name="cusolver-include")
filegroup(name="cufft-include")
filegroup(name="cusparse-include")
filegroup(name="curand-include")
filegroup(name="cudnn-include")
""",
    }
    for runtime_lib in [
        "cudart",
        "cublas",
        "cublasLt",
        "cusolver",
        "cudnn",
        "cufft",
        "curand",
        "cupti",
        "cusparse",
    ]:
        build_substitutions["%{%s_lib}" % runtime_lib] = lib_name(runtime_lib, cpu_value)
    _tpl(repository_ctx, "cuda:BUILD", build_substitutions)

    # Create dummy files for the CUDA toolkit since they are still required by
    # tensorflow/core/platform/default/build_config:cuda.
    for header in ["cuda.h", "cublas.h", "cudnn.h"]:
        repository_ctx.file("cuda/cuda/include/" + header)
    repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
    for stub_lib in [
        "cuda",
        "cudart",
        "cudart_static",
        "cublas",
        "cublasLt",
        "cusolver",
        "cudnn",
        "curand",
        "cufft",
        "cupti",
        "cusparse",
    ]:
        repository_ctx.file("cuda/cuda/lib/%s" % lib_name(stub_lib, cpu_value))

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc. Every value is blank because
    # no toolkit is present.
    _tpl(
        repository_ctx,
        "cuda:cuda_config.h",
        {"%{%s}" % key: "" for key in [
            "cuda_version",
            "cudart_version",
            "cublas_version",
            "cusolver_version",
            "curand_version",
            "cufft_version",
            "cusparse_version",
            "cudnn_version",
            "cuda_toolkit_path",
            "cuda_compute_capabilities",
        ]},
        "cuda/cuda/cuda_config.h",
    )

    # Set up cuda_config.py, which is used by gen_build_info to provide
    # static build environment info to the API
    _tpl(
        repository_ctx,
        "cuda:cuda_config.py",
        _py_tmpl_dict({}),
        "cuda/cuda/cuda_config.py",
    )

    # If cuda_configure is not configured to build with GPU support, and the
    # user attempts to build with --config=cuda, add a dummy build rule to
    # intercept this and fail with an actionable error message.
    repository_ctx.file(
        "crosstool/error_gpu_disabled.bzl",
        _DUMMY_CROSSTOOL_BZL_FILE,
    )
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
867
def _norm_path(path):
    """Normalizes a path to use '/' separators with no trailing slash.

    Args:
      path: The path string to normalize; may be empty.

    Returns:
      `path` with backslashes replaced by '/' and at most one trailing
      slash removed.
    """
    path = path.replace("\\", "/")

    # Use endswith instead of indexing path[-1]: indexing would fail on the
    # empty string.
    if path.endswith("/"):
        path = path[:-1]
    return path
874
def make_copy_files_rule(repository_ctx, name, srcs, outs):
    """Returns a genrule text that copies each of `srcs` to the matching `outs`."""

    # One 'cp' invocation per (src, out) pair; the pairs are matched by index.
    copy_cmds = [
        'cp -f "%s" "$(location %s)"' % (src, out)
        for src, out in zip(srcs, outs)
    ]
    out_lines = ['        "%s",' % out for out in outs]
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""%s \""",
)""" % (name, "\n".join(out_lines), " && \\\n".join(copy_cmds))
890
def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None):
    """Returns a genrule text that recursively copies a directory.

    If `exceptions` is not None, it must be a list of files or directories in
    `src_dir`; these will be excluded from copying.
    """
    src_dir = _norm_path(src_dir)
    out_dir = _norm_path(out_dir)
    files = read_dir(repository_ctx, src_dir)

    # Drop any file that lives under one of the excluded sub-paths.
    if exceptions != None:
        kept = []
        for f in files:
            excluded = False
            for exception in exceptions:
                if f.startswith(src_dir + "/" + exception):
                    excluded = True
                    break
            if not excluded:
                kept.append(f)
        files = kept
    out_lines = ['        "%s",' % f.replace(src_dir, out_dir) for f in files]

    # '@D' already contains the relative path for a single file, see
    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
    dest_dir = "$(@D)/%s" % out_dir if len(out_lines) > 1 else "$(@D)"

    # Excluded entries may still be copied by 'cp -r'; remove them afterwards.
    post_cmd = ""
    if exceptions != None:
        for exception in exceptions:
            post_cmd += " ; rm -fR " + dest_dir + "/" + exception
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""cp -rLf "%s/." "%s/" %s\""",
)""" % (name, "\n".join(out_lines), src_dir, dest_dir, post_cmd)
920
def _flag_enabled(repository_ctx, flag_name):
    """Returns True iff the environment variable `flag_name` is set to "1"."""
    value = get_host_environ(repository_ctx, flag_name)
    return value == "1"
923
def _use_cuda_clang(repository_ctx):
    """Returns True iff TF_CUDA_CLANG=1, i.e. clang compiles the CUDA code."""
    return get_host_environ(repository_ctx, "TF_CUDA_CLANG") == "1"
926
def _tf_sysroot(repository_ctx):
    """Returns the TF_SYSROOT environment value, or "" when unset."""
    sysroot = get_host_environ(repository_ctx, _TF_SYSROOT, "")
    return sysroot
929
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
    """Returns (as a string) the clang flags selecting the given GPU archs.

    A "compute_XY" capability additionally embeds PTX for that arch so the
    binary stays forward-compatible; a plain "sm_XY" only emits SASS.
    """
    flags = []
    for arch in compute_capabilities:
        if arch.startswith("compute_"):
            sm_arch = arch.replace("compute_", "sm_")
            flags.append("--cuda-include-ptx=%s" % sm_arch)
            flags.append("--cuda-gpu-arch=%s" % sm_arch)
        else:
            flags.append("--cuda-gpu-arch=%s" % arch)

    return str(flags)
939
def _tpl_path(repository_ctx, filename):
    """Resolves the path of the checked-in '<filename>.tpl' template."""
    tpl_label = Label("//third_party/gpus/%s.tpl" % filename)
    return repository_ctx.path(tpl_label)
942
def _basename(repository_ctx, path_str):
    """Returns the basename of a path of type string.

    This method is different from path.basename in that it also works if
    the host platform is different from the execution platform
    i.e. linux -> windows.
    """

    # Find the rightmost separator; '\\' only counts on Windows.
    sep_idx = path_str.rfind("/")
    if is_windows(repository_ctx):
        sep_idx = max(sep_idx, path_str.rfind("\\"))

    # rfind returns -1 when no separator is present, in which case the slice
    # starts at index 0 and yields the whole string.
    return path_str[sep_idx + 1:]
958
def _create_local_cuda_repository(repository_ctx):
    """Creates the repository containing files set up to build with CUDA."""

    # Resolve all labels before doing any real work. Resolving causes the
    # function to be restarted with all previous state being lost. This
    # can easily lead to a O(n^2) runtime in the number of labels.
    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
    tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [
        "cuda:build_defs.bzl",
        "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
        "crosstool:windows/msvc_wrapper_for_nvcc.py",
        "crosstool:BUILD",
        "crosstool:cc_toolchain_config.bzl",
        "cuda:cuda_config.h",
        "cuda:cuda_config.py",
    ]}

    # Windows uses a dedicated BUILD template.
    tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
    find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))

    # Discovers versions and include/library directories of the locally
    # installed CUDA toolkit and cuDNN (via the find_cuda_config script).
    cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)

    cuda_include_path = cuda_config.config["cuda_include_dir"]
    cublas_include_path = cuda_config.config["cublas_include_dir"]
    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
    cupti_header_dir = cuda_config.config["cupti_include_dir"]
    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]

    # Create genrule to copy files from the installed CUDA toolkit into execroot.
    copy_rules = [
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-include",
            src_dir = cuda_include_path,
            out_dir = "cuda/include",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-nvvm",
            src_dir = nvvm_libdevice_dir,
            out_dir = "cuda/nvvm/libdevice",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-extras",
            src_dir = cupti_header_dir,
            out_dir = "cuda/extras/CUPTI/include",
        ),
    ]

    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cublas-include",
        srcs = [
            cublas_include_path + "/cublas.h",
            cublas_include_path + "/cublas_v2.h",
            cublas_include_path + "/cublas_api.h",
            cublas_include_path + "/cublasLt.h",
        ],
        outs = [
            "cublas/include/cublas.h",
            "cublas/include/cublas_v2.h",
            "cublas/include/cublas_api.h",
            "cublas/include/cublasLt.h",
        ],
    ))

    cusolver_include_path = cuda_config.config["cusolver_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cusolver-include",
        srcs = [
            cusolver_include_path + "/cusolver_common.h",
            cusolver_include_path + "/cusolverDn.h",
        ],
        outs = [
            "cusolver/include/cusolver_common.h",
            "cusolver/include/cusolverDn.h",
        ],
    ))

    cufft_include_path = cuda_config.config["cufft_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cufft-include",
        srcs = [
            cufft_include_path + "/cufft.h",
        ],
        outs = [
            "cufft/include/cufft.h",
        ],
    ))

    cusparse_include_path = cuda_config.config["cusparse_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cusparse-include",
        srcs = [
            cusparse_include_path + "/cusparse.h",
        ],
        outs = [
            "cusparse/include/cusparse.h",
        ],
    ))

    curand_include_path = cuda_config.config["curand_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "curand-include",
        srcs = [
            curand_include_path + "/curand.h",
        ],
        outs = [
            "curand/include/curand.h",
        ],
    ))

    # Locate the CUDA libraries (validated via check_cuda_libs.py inside
    # _find_libs) and stage them under cuda/lib/.
    check_cuda_libs_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:check_cuda_libs.py"))
    cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
    cuda_lib_srcs = []
    cuda_lib_outs = []
    for path in cuda_libs.values():
        cuda_lib_srcs.append(path)
        cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path))
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cuda-lib",
        srcs = cuda_lib_srcs,
        outs = cuda_lib_outs,
    ))

    # copy files mentioned in third_party/nccl/build_defs.bzl.tpl
    file_ext = ".exe" if is_windows(repository_ctx) else ""
    bin_files = (
        ["crt/link.stub"] +
        [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]]
    )
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cuda-bin",
        srcs = [cuda_config.cuda_toolkit_path + "/bin/" + f for f in bin_files],
        outs = ["cuda/bin/" + f for f in bin_files],
    ))

    # Select the headers based on the cuDNN version (strip '64_' for Windows).
    # NOTE(review): this is a lexicographic string compare, correct for
    # single-digit majors but a hypothetical major "10" would sort before "8".
    cudnn_headers = ["cudnn.h"]
    if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8":
        cudnn_headers += [
            "cudnn_backend.h",
            "cudnn_adv_infer.h",
            "cudnn_adv_train.h",
            "cudnn_cnn_infer.h",
            "cudnn_cnn_train.h",
            "cudnn_ops_infer.h",
            "cudnn_ops_train.h",
            "cudnn_version.h",
        ]

    cudnn_srcs = []
    cudnn_outs = []
    for header in cudnn_headers:
        cudnn_srcs.append(cudnn_header_dir + "/" + header)
        cudnn_outs.append("cudnn/include/" + header)

    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cudnn-include",
        srcs = cudnn_srcs,
        outs = cudnn_outs,
    ))

    # Set up BUILD file for cuda/
    repository_ctx.template(
        "cuda/build_defs.bzl",
        tpl_paths["cuda:build_defs.bzl"],
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
                repository_ctx,
                cuda_config.compute_capabilities,
            ),
            "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities),
        },
    )

    # From CUDA 11 onwards, CUB ships with the toolkit headers; earlier
    # versions pull it from the external cub archive.
    cub_actual = "@cub_archive//:cub"
    if int(cuda_config.cuda_version_major) >= 11:
        cub_actual = ":cuda_headers"

    repository_ctx.template(
        "cuda/BUILD",
        tpl_paths["cuda:BUILD"],
        {
            "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]),
            "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]),
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
            "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
            "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
            "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
            "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
            "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
            "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),
            "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]),
            "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]),
            "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]),
            "%{cub_actual}": cub_actual,
            "%{copy_rules}": "\n".join(copy_rules),
        },
    )

    # ---- Crosstool (C++ toolchain) half of the repository. ----
    is_cuda_clang = _use_cuda_clang(repository_ctx)
    tf_sysroot = _tf_sysroot(repository_ctx)

    should_download_clang = is_cuda_clang and _flag_enabled(
        repository_ctx,
        _TF_DOWNLOAD_CLANG,
    )
    if should_download_clang:
        download_clang(repository_ctx, "crosstool/extra_tools")

    # Set up crosstool/
    cc = find_cc(repository_ctx)
    cc_fullpath = cc if not should_download_clang else "crosstool/" + cc

    host_compiler_includes = get_cxx_inc_directories(
        repository_ctx,
        cc_fullpath,
        tf_sysroot,
    )
    cuda_defines = {}
    cuda_defines["%{builtin_sysroot}"] = tf_sysroot

    # The toolkit path is only filled in for clang (below); empty otherwise.
    cuda_defines["%{cuda_toolkit_path}"] = ""
    cuda_defines["%{compiler}"] = "unknown"
    if is_cuda_clang:
        cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
        cuda_defines["%{compiler}"] = "clang"

    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
    if not host_compiler_prefix:
        host_compiler_prefix = "/usr/bin"

    cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix

    # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
    # https://github.com/bazelbuild/bazel/issues/760).
    # However, this stops our custom clang toolchain from picking the provided
    # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
    # toolchain.
    # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
    #       flag from the CROSSTOOL completely (see
    #       https://github.com/bazelbuild/bazel/issues/5634)
    if should_download_clang:
        cuda_defines["%{linker_bin_path}"] = ""
    else:
        cuda_defines["%{linker_bin_path}"] = host_compiler_prefix

    cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ""
    cuda_defines["%{unfiltered_compile_flags}"] = ""
    if is_cuda_clang:
        # clang compiles both host and device code directly; the nvcc wrapper
        # scripts are written out empty because the BUILD file references them.
        cuda_defines["%{host_compiler_path}"] = str(cc)
        cuda_defines["%{host_compiler_warnings}"] = """
        # Some parts of the codebase set -Werror and hit this warning, so
        # switch it off for now.
        "-Wno-invalid-partial-specialization"
    """
        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes)
        cuda_defines["%{compiler_deps}"] = ":empty"
        cuda_defines["%{win_compiler_deps}"] = ":empty"
        repository_ctx.file(
            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
            "",
        )
        repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
    else:
        cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
        cuda_defines["%{host_compiler_warnings}"] = ""

        # nvcc has the system include paths built in and will automatically
        # search them; we cannot work around that, so we add the relevant cuda
        # system paths to the allowed compiler specific include paths.
        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
            host_compiler_includes + _cuda_include_path(
                repository_ctx,
                cuda_config,
            ) + [cupti_header_dir, cudnn_header_dir],
        )

        # For gcc, do not canonicalize system header paths; some versions of gcc
        # pick the shortest possible path for system includes when creating the
        # .d file - given that includes that are prefixed with "../" multiple
        # time quickly grow longer than the root of the tree, this can lead to
        # bazel's header check failing.
        cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""

        file_ext = ".exe" if is_windows(repository_ctx) else ""
        nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext)
        cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc"
        cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files"

        # Both wrapper scripts (Linux and Windows) take the same substitutions.
        wrapper_defines = {
            "%{cpu_compiler}": str(cc),
            "%{cuda_version}": cuda_config.cuda_version,
            "%{nvcc_path}": nvcc_path,
            "%{gcc_host_compiler_path}": str(cc),
            "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
        }
        repository_ctx.template(
            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
            tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"],
            wrapper_defines,
        )
        repository_ctx.template(
            "crosstool/windows/msvc_wrapper_for_nvcc.py",
            tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"],
            wrapper_defines,
        )

    cuda_defines.update(_get_win_cuda_defines(repository_ctx))

    verify_build_defines(cuda_defines)

    # Only expand template variables in the BUILD file
    repository_ctx.template(
        "crosstool/BUILD",
        tpl_paths["crosstool:BUILD"],
        cuda_defines,
    )

    # No templating of cc_toolchain_config - use attributes and templatize the
    # BUILD file.
    repository_ctx.template(
        "crosstool/cc_toolchain_config.bzl",
        tpl_paths["crosstool:cc_toolchain_config.bzl"],
        {},
    )

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "cuda/cuda/cuda_config.h",
        tpl_paths["cuda:cuda_config.h"],
        {
            "%{cuda_version}": cuda_config.cuda_version,
            "%{cudart_version}": cuda_config.cudart_version,
            "%{cublas_version}": cuda_config.cublas_version,
            "%{cusolver_version}": cuda_config.cusolver_version,
            "%{curand_version}": cuda_config.curand_version,
            "%{cufft_version}": cuda_config.cufft_version,
            "%{cusparse_version}": cuda_config.cusparse_version,
            "%{cudnn_version}": cuda_config.cudnn_version,
            "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
            "%{cuda_compute_capabilities}": ", ".join([
                cc.split("_")[1]
                for cc in cuda_config.compute_capabilities
            ]),
        },
    )

    # Set up cuda_config.py, which is used by gen_build_info to provide
    # static build environment info to the API
    repository_ctx.template(
        "cuda/cuda/cuda_config.py",
        tpl_paths["cuda:cuda_config.py"],
        _py_tmpl_dict({
            "cuda_version": cuda_config.cuda_version,
            "cudnn_version": cuda_config.cudnn_version,
            "cuda_compute_capabilities": cuda_config.compute_capabilities,
            "cpu_compiler": str(cc),
        }),
    )
1328
def _py_tmpl_dict(d):
    """Wraps `d` as the single substitution consumed by cuda_config.py.tpl."""
    config_repr = str(d)
    return {"%{cuda_config}": config_repr}
1331
def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
    """Creates pointers to a remotely configured repo set up to build with CUDA."""
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
                repository_ctx,
                compute_capabilities(repository_ctx),
            ),
        },
    )

    # Files mirrored verbatim (no substitutions) from the remote config repo.
    for out, src in [
        ("cuda/BUILD", "cuda:BUILD"),
        ("cuda/build_defs.bzl", "cuda:build_defs.bzl"),
        ("cuda/cuda/cuda_config.h", "cuda:cuda/cuda_config.h"),
    ]:
        repository_ctx.template(
            out,
            config_repo_label(remote_config_repo, src),
            {},
        )

    # cuda_config.py is the one file that still takes a substitution.
    repository_ctx.template(
        "cuda/cuda/cuda_config.py",
        config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.py"),
        _py_tmpl_dict({}),
    )

    # Crosstool files are likewise mirrored verbatim.
    for out, src in [
        ("crosstool/BUILD", "crosstool:BUILD"),
        ("crosstool/cc_toolchain_config.bzl", "crosstool:cc_toolchain_config.bzl"),
        ("crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"),
    ]:
        repository_ctx.template(
            out,
            config_repo_label(remote_config_repo, src),
            {},
        )
1383
def _cuda_autoconf_impl(repository_ctx):
    """Implementation of the cuda_autoconf repository rule."""
    build_file = Label("//third_party/gpus:local_config_cuda.BUILD")

    if not enable_cuda(repository_ctx):
        # CUDA disabled: emit stubs plus an actionable --config=cuda error.
        _create_dummy_repository(repository_ctx)
    else:
        config_repo = get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO)
        if config_repo == None:
            # No pre-generated config repo: detect the local toolchain.
            _create_local_cuda_repository(repository_ctx)
        else:
            # A remote config repo requires the versions to be pinned too.
            missing_version = (
                get_host_environ(repository_ctx, _TF_CUDA_VERSION) == None or
                get_host_environ(repository_ctx, _TF_CUDNN_VERSION) == None
            )
            if missing_version:
                auto_configure_fail("%s and %s must also be set if %s is specified" %
                                    (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
            _create_remote_cuda_repository(repository_ctx, config_repo)

    repository_ctx.symlink(build_file, "BUILD")
1404
# For @bazel_tools//tools/cpp:windows_cc_configure.bzl
# Environment variables consulted when locating MSVC/Visual Studio on Windows.
_MSVC_ENVVARS = [
    "BAZEL_VC",
    "BAZEL_VC_FULL_VERSION",
    "BAZEL_VS",
    "BAZEL_WINSDK_FULL_VERSION",
    "VS90COMNTOOLS",
    "VS100COMNTOOLS",
    "VS110COMNTOOLS",
    "VS120COMNTOOLS",
    "VS140COMNTOOLS",
    "VS150COMNTOOLS",
    "VS160COMNTOOLS",
]
1419
# Environment variables the repository rules depend on; a change to any of
# them invalidates the configured repository and triggers re-configuration.
_ENVIRONS = [
    _GCC_HOST_COMPILER_PATH,
    _GCC_HOST_COMPILER_PREFIX,
    _CLANG_CUDA_COMPILER_PATH,
    "TF_NEED_CUDA",
    "TF_CUDA_CLANG",
    _TF_DOWNLOAD_CLANG,
    _CUDA_TOOLKIT_PATH,
    _CUDNN_INSTALL_PATH,
    _TF_CUDA_VERSION,
    _TF_CUDNN_VERSION,
    _TF_CUDA_COMPUTE_CAPABILITIES,
    "NVVMIR_LIBRARY_DIR",
    _PYTHON_BIN_PATH,
    "TMP",
    "TMPDIR",
    "TF_CUDA_PATHS",
] + _MSVC_ENVVARS
1438
# Remotable variant: always runs local CUDA detection (it binds
# _create_local_cuda_repository directly, skipping the TF_CUDA_CONFIG_REPO
# indirection). The "environ" attribute allows the remote config
# infrastructure to inject environment values.
remote_cuda_configure = repository_rule(
    implementation = _create_local_cuda_repository,
    environ = _ENVIRONS,
    remotable = True,
    attrs = {
        "environ": attr.string_dict(),
    },
)
1447
# Main entry point; additionally re-runs when TF_CUDA_CONFIG_REPO changes.
cuda_configure = repository_rule(
    implementation = _cuda_autoconf_impl,
    environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO],
)
"""Detects and configures the local CUDA toolchain.

Add the following to your WORKSPACE FILE:

```python
cuda_configure(name = "local_config_cuda")
```

Args:
  name: A unique name for this workspace rule.
"""
1463