# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sets up a repository for using the system-provided TensorFlow."""

def _get_tf_sysconfig_flags(repository_ctx, python3, getter, kind):
    """Returns the flags reported by a `tf.sysconfig` getter function.

    Runs `python3` to import TensorFlow and print the result of
    `tf.sysconfig.<getter>()`, NUL-separated (NUL cannot occur inside a
    command-line flag), then splits the output back into a list.

    Args:
      repository_ctx: The repository context.
      python3: The Python interpreter to execute.
      getter: Name of the `tf.sysconfig` function to call, e.g.
        "get_compile_flags".
      kind: Flag kind ("compile" or "link") used in the failure message.

    Returns:
      The non-empty flags, in the order TensorFlow reported them.
    """
    result = repository_ctx.execute([
        python3,
        "-c",
        ";".join([
            "import tensorflow as tf",
            "print('\\0'.join(tf.sysconfig.%s()))" % getter,
        ]),
    ])
    if result.return_code != 0:
        message = "Failed to determine TensorFlow %s flags; is TensorFlow installed?" % kind
        if result.stderr:
            # Surface the interpreter's own error output to aid debugging.
            message += "\n" + result.stderr
        fail(message)

    # Skip empty entries; splitting an empty output would otherwise yield [""].
    return [flag for flag in result.stdout.strip().split("\0") if flag]

def _process_compile_flags(repository_ctx, python3, headers_dir):
    """Processes compilation flags required by the system-provided TF package.

    The tf.sysconfig module provides the compilation flags that should be used
    for custom operators. These will include the directory containing the
    TensorFlow C++ headers ("-I/some/path") and possibly other flags (e.g.,
    "-DSOME_FLAG=2").

    A symlink is created from `headers_dir` to the directory containing the C++
    headers. The list of other flags is returned.

    Args:
      repository_ctx: The repository context.
      python3: The Python interpreter used to query TensorFlow.
      headers_dir: Repository-relative path at which to symlink the headers.

    Returns:
      A `(copts, cxxopts)` tuple: flags for all C/C++ compiles, and
      C++-only flags (the language-standard flag).
    """
    include_dir = None
    copts = []
    cxxopts = []
    flags = _get_tf_sysconfig_flags(repository_ctx, python3, "get_compile_flags", "compile")
    for flag in flags:
        if flag.startswith("-I"):
            if include_dir != None:
                fail("Only one TensorFlow headers directory is supported.")
            include_dir = flag[2:]
        elif flag.startswith("--std=c++") or flag.startswith("-std=c++"):
            # Don't add C++-only flags to copts. Both the "--std=" and "-std="
            # spellings are accepted by compilers, so route either one here.
            cxxopts.append(flag)
        else:
            copts.append(flag)

    if not include_dir:
        fail("Unable to find TensorFlow headers directory.")
    repository_ctx.symlink(include_dir, headers_dir)

    return copts, cxxopts

def _resolve_library_file_name(repository_ctx, link_dir, name):
    """Maps the argument of a "-l" flag to a file name inside `link_dir`.

    "-l:FILE" requests exact library name resolution, so FILE is used as-is.
    A plain "-lNAME" makes the linker search for "libNAME" with a shared
    library extension, so "libNAME.so" and "libNAME.dylib" are tried in
    `link_dir`. If neither exists, NAME itself is returned unchanged.

    Args:
      repository_ctx: The repository context.
      link_dir: Directory named by the "-L" flag.
      name: The text following "-l" in the flag.

    Returns:
      The library's file name relative to `link_dir`.
    """
    if name.startswith(":"):
        return name.lstrip(":")
    for candidate in ["lib%s.so" % name, "lib%s.dylib" % name]:
        if repository_ctx.path(link_dir + "/" + candidate).exists:
            return candidate
    return name

def _process_link_flags(repository_ctx, python3, library_file):
    """Processes linker flags required by the system-provided TF package.

    The tf.sysconfig module provides the linker flags that should be used
    for custom operators. These will include the directory containing
    libtensorflow_framework.so ("-L/some/path"), the library to link
    ("-l:libtensorflow_framework.so.2"), and possibly other flags.

    A symlink is created from `library_file` to libtensorflow_framework.so. The
    list of other flags is returned.

    Args:
      repository_ctx: The repository context.
      python3: The Python interpreter used to query TensorFlow.
      library_file: Repository-relative path at which to symlink the library.

    Returns:
      The linker flags other than "-L" and "-l".
    """
    link_dir = None
    library = None
    linkopts = []
    flags = _get_tf_sysconfig_flags(repository_ctx, python3, "get_link_flags", "link")
    for flag in flags:
        if flag.startswith("-L"):
            if link_dir != None:
                fail("Only one TensorFlow libraries directory is supported.")
            link_dir = flag[2:]
        elif flag.startswith("-l"):
            if library != None:
                fail("Only one TensorFlow library is supported.")

            # Resolution into a file name is deferred until after the loop,
            # since the "-L" directory may be listed after the "-l" flag.
            library = flag[2:]
        else:
            linkopts.append(flag)

    if link_dir and library:
        library = _resolve_library_file_name(repository_ctx, link_dir, library)
    if not link_dir or not library:
        fail("Unable to find TensorFlow library.")
    repository_ctx.symlink(link_dir + "/" + library, library_file)

    return linkopts

def _tf_custom_op_configure_impl(repository_ctx):
    """Defines a repository for using the system-provided TensorFlow package.

    This is a lot like new_local_repository except that (a) the files to
    include are dynamically determined using TensorFlow's `tf.sysconfig` Python
    module, and (b) it provides build rules to compile and link C++ code with
    the necessary options to be compatible with the system-provided TensorFlow
    package.
    """

    # The interpreter can be overridden via PYTHON_BIN_PATH (declared in the
    # rule's `environ` so that changing it re-runs this rule).
    python3 = repository_ctx.os.environ.get("PYTHON_BIN_PATH", "python3")

    # Name of the sub-directory that will link to TensorFlow C++ headers.
    headers_dir = "headers"

    # Name of the file that will link to libtensorflow_framework.so.
    library_file = "libtensorflow_framework.so"

    copts, cxxopts = _process_compile_flags(repository_ctx, python3, headers_dir)
    linkopts = _process_link_flags(repository_ctx, python3, library_file)

    # Create a BUILD file providing targets for the TensorFlow C++ headers and
    # framework library.
    repository_ctx.template(
        "BUILD",
        Label("//fcp/tensorflow/system_provided_tf:templates/BUILD.tpl"),
        substitutions = {
            "%{HEADERS_DIR}": headers_dir,
            "%{LIBRARY_FILE}": library_file,
        },
        executable = False,
    )

    # Create a bzl file providing rules for compiling C++ code compatible with
    # the TensorFlow package. The option lists are substituted as their
    # Starlark list literals (via str()).
    repository_ctx.template(
        "system_provided_tf.bzl",
        Label("//fcp/tensorflow/system_provided_tf:templates/system_provided_tf.bzl.tpl"),
        substitutions = {
            "%{COPTS}": str(copts),
            "%{CXXOPTS}": str(cxxopts),
            "%{LINKOPTS}": str(linkopts),
            "%{REPOSITORY_NAME}": repository_ctx.name,
        },
        executable = False,
    )

system_provided_tf = repository_rule(
    implementation = _tf_custom_op_configure_impl,
    configure = True,
    doc = """Creates a repository with targets for the system-provided TensorFlow.

This repository defines (a) //:tf_headers providing the C++ TensorFlow headers,
(b) //:libtensorflow_framework providing the TensorFlow framework shared
library, and (c) //:system_provided_tf.bzl for building custom op libraries
that are compatible with the system-provided TensorFlow package.
""",
    environ = [
        "PYTHON_BIN_PATH",
    ],
    local = True,
)