• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Sets up a repository for using the system-provided TensorFlow."""
15
def _process_compile_flags(repository_ctx, python3, headers_dir):
    """Processes compilation flags required by the system-provided TF package.

    The tf.sysconfig module provides the compilation flags that should be used
    for custom operators. These will include the directory containing the
    TensorFlow C++ headers ("-I/some/path") and possibly other flags (e.g.,
    "-DSOME_FLAG=2").

    A symlink is created from `headers_dir` to the directory containing the C++
    headers. The list of other flags is returned.

    Args:
      repository_ctx: The repository context; used to run the Python
        interpreter and to create the headers symlink.
      python3: The Python 3 interpreter to run (a path or a command name).
      headers_dir: Name of the symlink to create, pointing at the TensorFlow
        headers directory.

    Returns:
      A `(copts, cxxopts)` tuple. `cxxopts` holds the flags beginning with
      "--std=c++"; `copts` holds every other flag except the "-I" one.
    """
    result = repository_ctx.execute([
        python3,
        "-c",
        ";".join([
            "import tensorflow as tf",
            # The flags are NUL-separated so that flags containing spaces
            # survive the round trip through stdout.
            "print('\\0'.join(tf.sysconfig.get_compile_flags()))",
        ]),
    ])
    if result.return_code != 0:
        # Surface the interpreter's stderr: without it, a missing interpreter,
        # a missing TensorFlow package, and a broken TensorFlow install are
        # indistinguishable to the user.
        fail(
            "Failed to determine TensorFlow compile flags; is TensorFlow installed?\n" +
            result.stderr,
        )
    include_dir = None
    copts = []
    cxxopts = []

    # strip() drops the trailing newline added by print().
    for flag in result.stdout.strip().split("\0"):
        if flag.startswith("-I"):
            if include_dir != None:
                fail("Only one TensorFlow headers directory is supported.")
            include_dir = flag[2:]
        elif flag.startswith("--std=c++"):  # Don't add C++-only flags to copts.
            cxxopts.append(flag)
        else:
            copts.append(flag)

    # Also rejects an empty "-I" path, not just a missing flag.
    if not include_dir:
        fail("Unable to find TensorFlow headers directory.")
    repository_ctx.symlink(include_dir, headers_dir)

    return copts, cxxopts
55
def _process_link_flags(repository_ctx, python3, library_file):
    """Processes linker flags required by the system-provided TF package.

    The tf.sysconfig module provides the linker flags that should be used
    for custom operators. These will include the directory containing
    libtensorflow_framework.so ("-L/some/path"), the library to link
    ("-l:libtensorflow_framework.so.2"), and possibly other flags.

    A symlink is created from `library_file` to libtensorflow_framework.so. The
    list of other flags is returned.

    Args:
      repository_ctx: The repository context; used to run the Python
        interpreter and to create the library symlink.
      python3: The Python 3 interpreter to run (a path or a command name).
      library_file: Name of the symlink to create, pointing at the TensorFlow
        framework shared library.

    Returns:
      The list of linker flags other than the "-L" and "-l" ones.
    """
    result = repository_ctx.execute([
        python3,
        "-c",
        ";".join([
            "import tensorflow as tf",
            # The flags are NUL-separated so that flags containing spaces
            # survive the round trip through stdout.
            "print('\\0'.join(tf.sysconfig.get_link_flags()))",
        ]),
    ])
    if result.return_code != 0:
        # Surface the interpreter's stderr: without it, a missing interpreter,
        # a missing TensorFlow package, and a broken TensorFlow install are
        # indistinguishable to the user.
        fail(
            "Failed to determine TensorFlow link flags; is TensorFlow installed?\n" +
            result.stderr,
        )

    # Suffix the linker would append when resolving a plain "-lNAME" flag.
    if "mac" in repository_ctx.os.name.lower():
        shared_library_suffix = ".dylib"
    else:
        shared_library_suffix = ".so"

    link_dir = None
    library = None
    linkopts = []

    # strip() drops the trailing newline added by print().
    for flag in result.stdout.strip().split("\0"):
        if flag.startswith("-L"):
            if link_dir != None:
                fail("Only one TensorFlow libraries directory is supported.")
            link_dir = flag[2:]
        elif flag.startswith("-l"):
            if library != None:
                fail("Only one TensorFlow library is supported.")

            library = flag[2:]
            if library.startswith(":"):
                # "-l" may be followed by ":" to force the linker to use exact
                # library name resolution; the remainder is the file name.
                library = library.lstrip(":")
            elif library:
                # A plain "-lNAME" refers to the file "libNAME" plus the
                # platform's shared-library suffix, not to a file called
                # "NAME"; symlinking the latter would dangle.
                library = "lib" + library + shared_library_suffix
        else:
            linkopts.append(flag)

    if not link_dir or not library:
        fail("Unable to find TensorFlow library.")
    repository_ctx.symlink(link_dir + "/" + library, library_file)

    return linkopts
100
def _tf_custom_op_configure_impl(repository_ctx):
    """Implements the repository rule for the system-provided TensorFlow.

    The generated repository resembles one made by new_local_repository, with
    two differences: the linked files are discovered at fetch time through
    TensorFlow's `tf.sysconfig` Python module, and the repository also ships
    build rules that compile and link C++ code with the options needed to be
    compatible with the system-provided TensorFlow package.

    Args:
      repository_ctx: The repository context of the rule being evaluated.
    """

    # Fall back to whatever "python3" resolves to when PYTHON_BIN_PATH is unset.
    interpreter = repository_ctx.os.environ.get("PYTHON_BIN_PATH", "python3")

    # Symlink names created inside the repository: a directory pointing at the
    # TensorFlow C++ headers and a file pointing at libtensorflow_framework.so.
    headers_link = "headers"
    library_link = "libtensorflow_framework.so"

    # Both helpers create their symlink as a side effect and hand back the
    # remaining flags.
    c_flags, cxx_flags = _process_compile_flags(repository_ctx, interpreter, headers_link)
    link_flags = _process_link_flags(repository_ctx, interpreter, library_link)

    # BUILD file with targets for the TensorFlow C++ headers and the framework
    # library.
    build_substitutions = {
        "%{HEADERS_DIR}": headers_link,
        "%{LIBRARY_FILE}": library_link,
    }
    repository_ctx.template(
        "BUILD",
        Label("//fcp/tensorflow/system_provided_tf:templates/BUILD.tpl"),
        substitutions = build_substitutions,
        executable = False,
    )

    # bzl file with rules for compiling C++ code compatible with the TensorFlow
    # package. The flag lists are rendered with str() before substitution.
    bzl_substitutions = {
        "%{COPTS}": str(c_flags),
        "%{CXXOPTS}": str(cxx_flags),
        "%{LINKOPTS}": str(link_flags),
        "%{REPOSITORY_NAME}": repository_ctx.name,
    }
    repository_ctx.template(
        "system_provided_tf.bzl",
        Label("//fcp/tensorflow/system_provided_tf:templates/system_provided_tf.bzl.tpl"),
        substitutions = bzl_substitutions,
        executable = False,
    )
146
# Repository rule that exposes the system-provided TensorFlow package (headers,
# framework library, and matching build rules) as an external repository.
system_provided_tf = repository_rule(
    implementation = _tf_custom_op_configure_impl,
    doc = """Creates a repository with targets for the system-provided TensorFlow.

This repository defines (a) //:tf_headers providing the C++ TensorFlow headers,
(b) //:libtensorflow_framework providing the TensorFlow framework shared
library, and (c) //:system_provided_tf.bzl for building custom op libraries
that are compatible with the system-provided TensorFlow package.
""",
    # The implementation reads PYTHON_BIN_PATH to choose the interpreter, so
    # the variable is declared as an input of the rule.
    environ = ["PYTHON_BIN_PATH"],
    # The rule's output depends on the TensorFlow package installed on the
    # local system, so it is marked as a local, configuration-probing rule.
    configure = True,
    local = True,
)
162