1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""configure script to get build parameters from user.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import errno 23import os 24import platform 25import re 26import subprocess 27import sys 28 29# pylint: disable=g-import-not-at-top 30try: 31 from shutil import which 32except ImportError: 33 from distutils.spawn import find_executable as which 34# pylint: enable=g-import-not-at-top 35 36_DEFAULT_CUDA_VERSION = '9.0' 37_DEFAULT_CUDNN_VERSION = '7' 38_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' 39_DEFAULT_CUDA_PATH = '/usr/local/cuda' 40_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' 41_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 42 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) 43_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' 44_TF_OPENCL_VERSION = '1.2' 45_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' 46_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' 47_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] 48 49_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 50 51_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) 52_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' 53_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) 54_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') 55 56 57class UserInputError(Exception): 58 pass 59 60 61def is_windows(): 62 return platform.system() == 'Windows' 63 64 65def is_linux(): 66 return platform.system() == 'Linux' 67 68 69def is_macos(): 70 return platform.system() == 'Darwin' 71 72 73def is_ppc64le(): 74 return platform.machine() == 'ppc64le' 75 76 77def is_cygwin(): 78 return platform.system().startswith('CYGWIN_NT') 79 80 81def get_input(question): 82 try: 83 try: 84 answer = raw_input(question) 85 except NameError: 86 answer = input(question) # pylint: disable=bad-builtin 87 except EOFError: 88 answer = '' 89 return answer 90 91 92def symlink_force(target, link_name): 93 """Force symlink, equivalent of 'ln -sf'. 94 95 Args: 96 target: items to link to. 97 link_name: name of the link. 98 """ 99 try: 100 os.symlink(target, link_name) 101 except OSError as e: 102 if e.errno == errno.EEXIST: 103 os.remove(link_name) 104 os.symlink(target, link_name) 105 else: 106 raise e 107 108 109def sed_in_place(filename, old, new): 110 """Replace old string with new string in file. 111 112 Args: 113 filename: string for filename. 114 old: string to replace. 115 new: new string to replace to. 116 """ 117 with open(filename, 'r') as f: 118 filedata = f.read() 119 newdata = filedata.replace(old, new) 120 with open(filename, 'w') as f: 121 f.write(newdata) 122 123 124def write_to_bazelrc(line): 125 with open(_TF_BAZELRC, 'a') as f: 126 f.write(line + '\n') 127 128 129def write_action_env_to_bazelrc(var_name, var): 130 write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) 131 132 133def run_shell(cmd, allow_non_zero=False): 134 if allow_non_zero: 135 try: 136 output = subprocess.check_output(cmd) 137 except subprocess.CalledProcessError as e: 138 output = e.output 139 else: 140 output = subprocess.check_output(cmd) 141 return output.decode('UTF-8').strip() 142 143 144def cygpath(path): 145 """Convert path from posix to windows.""" 146 return os.path.abspath(path).replace('\\', '/') 147 148 149def get_python_path(environ_cp, python_bin_path): 150 """Get the python site package paths.""" 151 python_paths = [] 152 if environ_cp.get('PYTHONPATH'): 153 python_paths = environ_cp.get('PYTHONPATH').split(':') 154 try: 155 library_paths = run_shell( 156 [python_bin_path, '-c', 157 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') 158 except subprocess.CalledProcessError: 159 library_paths = [run_shell( 160 [python_bin_path, '-c', 161 'from distutils.sysconfig import get_python_lib;' 162 'print(get_python_lib())'])] 163 164 all_paths = set(python_paths + library_paths) 165 166 paths = [] 167 for path in all_paths: 168 if os.path.isdir(path): 169 paths.append(path) 170 return paths 171 172 173def get_python_major_version(python_bin_path): 174 """Get the python major version.""" 175 return run_shell([python_bin_path, '-c', 'import sys; print(sys.version[0])']) 176 177 178def setup_python(environ_cp): 179 """Setup python related env variables.""" 180 # Get PYTHON_BIN_PATH, default is the current running python. 181 default_python_bin_path = sys.executable 182 ask_python_bin_path = ('Please specify the location of python. [Default is ' 183 '%s]: ') % default_python_bin_path 184 while True: 185 python_bin_path = get_from_env_or_user_or_default( 186 environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path, 187 default_python_bin_path) 188 # Check if the path is valid 189 if os.path.isfile(python_bin_path) and os.access( 190 python_bin_path, os.X_OK): 191 break 192 elif not os.path.exists(python_bin_path): 193 print('Invalid python path: %s cannot be found.' % python_bin_path) 194 else: 195 print('%s is not executable. Is it the python binary?' % python_bin_path) 196 environ_cp['PYTHON_BIN_PATH'] = '' 197 198 # Convert python path to Windows style before checking lib and version 199 if is_windows() or is_cygwin(): 200 python_bin_path = cygpath(python_bin_path) 201 202 # Get PYTHON_LIB_PATH 203 python_lib_path = environ_cp.get('PYTHON_LIB_PATH') 204 if not python_lib_path: 205 python_lib_paths = get_python_path(environ_cp, python_bin_path) 206 if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1': 207 python_lib_path = python_lib_paths[0] 208 else: 209 print('Found possible Python library paths:\n %s' % 210 '\n '.join(python_lib_paths)) 211 default_python_lib_path = python_lib_paths[0] 212 python_lib_path = get_input( 213 'Please input the desired Python library path to use. ' 214 'Default is [%s]\n' % python_lib_paths[0]) 215 if not python_lib_path: 216 python_lib_path = default_python_lib_path 217 environ_cp['PYTHON_LIB_PATH'] = python_lib_path 218 219 python_major_version = get_python_major_version(python_bin_path) 220 221 # Convert python path to Windows style before writing into bazel.rc 222 if is_windows() or is_cygwin(): 223 python_lib_path = cygpath(python_lib_path) 224 225 # Set-up env variables used by python_configure.bzl 226 write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) 227 write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) 228 write_to_bazelrc('build --force_python=py%s' % python_major_version) 229 write_to_bazelrc('build --host_force_python=py%s' % python_major_version) 230 write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) 231 environ_cp['PYTHON_BIN_PATH'] = python_bin_path 232 233 # Write tools/python_bin_path.sh 234 with open(os.path.join( 235 _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: 236 f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) 237 238 239def reset_tf_configure_bazelrc(workspace_path): 240 """Reset file that contains customized config settings.""" 241 open(_TF_BAZELRC, 'w').close() 242 bazelrc_path = os.path.join(workspace_path, '.bazelrc') 243 244 data = [] 245 if os.path.exists(bazelrc_path): 246 with open(bazelrc_path, 'r') as f: 247 data = f.read().splitlines() 248 with open(bazelrc_path, 'w') as f: 249 for l in data: 250 if _TF_BAZELRC_FILENAME in l: 251 continue 252 f.write('%s\n' % l) 253 f.write('import %s\n' % _TF_BAZELRC) 254 255 256def cleanup_makefile(): 257 """Delete any leftover BUILD files from the Makefile build. 258 259 These files could interfere with Bazel parsing. 260 """ 261 makefile_download_dir = os.path.join( 262 _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') 263 if os.path.isdir(makefile_download_dir): 264 for root, _, filenames in os.walk(makefile_download_dir): 265 for f in filenames: 266 if f.endswith('BUILD'): 267 os.remove(os.path.join(root, f)) 268 269 270def get_var(environ_cp, 271 var_name, 272 query_item, 273 enabled_by_default, 274 question=None, 275 yes_reply=None, 276 no_reply=None): 277 """Get boolean input from user. 278 279 If var_name is not set in env, ask user to enable query_item or not. If the 280 response is empty, use the default. 281 282 Args: 283 environ_cp: copy of the os.environ. 284 var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". 285 query_item: string for feature related to the variable, e.g. "Hadoop File 286 System". 287 enabled_by_default: boolean for default behavior. 288 question: optional string for how to ask for user input. 289 yes_reply: optional string for reply when feature is enabled. 290 no_reply: optional string for reply when feature is disabled. 291 292 Returns: 293 boolean value of the variable. 294 295 Raises: 296 UserInputError: if an environment variable is set, but it cannot be 297 interpreted as a boolean indicator, assume that the user has made a 298 scripting error, and will continue to provide invalid input. 299 Raise the error to avoid infinitely looping. 300 """ 301 if not question: 302 question = 'Do you wish to build TensorFlow with %s support?' % query_item 303 if not yes_reply: 304 yes_reply = '%s support will be enabled for TensorFlow.' % query_item 305 if not no_reply: 306 no_reply = 'No %s' % yes_reply 307 308 yes_reply += '\n' 309 no_reply += '\n' 310 311 if enabled_by_default: 312 question += ' [Y/n]: ' 313 else: 314 question += ' [y/N]: ' 315 316 var = environ_cp.get(var_name) 317 if var is not None: 318 var_content = var.strip().lower() 319 true_strings = ('1', 't', 'true', 'y', 'yes') 320 false_strings = ('0', 'f', 'false', 'n', 'no') 321 if var_content in true_strings: 322 var = True 323 elif var_content in false_strings: 324 var = False 325 else: 326 raise UserInputError( 327 'Environment variable %s must be set as a boolean indicator.\n' 328 'The following are accepted as TRUE : %s.\n' 329 'The following are accepted as FALSE: %s.\n' 330 'Current value is %s.' % ( 331 var_name, ', '.join(true_strings), ', '.join(false_strings), 332 var)) 333 334 while var is None: 335 user_input_origin = get_input(question) 336 user_input = user_input_origin.strip().lower() 337 if user_input == 'y': 338 print(yes_reply) 339 var = True 340 elif user_input == 'n': 341 print(no_reply) 342 var = False 343 elif not user_input: 344 if enabled_by_default: 345 print(yes_reply) 346 var = True 347 else: 348 print(no_reply) 349 var = False 350 else: 351 print('Invalid selection: %s' % user_input_origin) 352 return var 353 354 355def set_build_var(environ_cp, var_name, query_item, option_name, 356 enabled_by_default, bazel_config_name=None): 357 """Set if query_item will be enabled for the build. 358 359 Ask user if query_item will be enabled. Default is used if no input is given. 360 Set subprocess environment variable and write to .bazelrc if enabled. 361 362 Args: 363 environ_cp: copy of the os.environ. 364 var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". 365 query_item: string for feature related to the variable, e.g. "Hadoop File 366 System". 367 option_name: string for option to define in .bazelrc. 368 enabled_by_default: boolean for default behavior. 369 bazel_config_name: Name for Bazel --config argument to enable build feature. 370 """ 371 372 var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) 373 environ_cp[var_name] = var 374 if var == '1': 375 write_to_bazelrc('build --define %s=true' % option_name) 376 elif bazel_config_name is not None: 377 # TODO(mikecase): Migrate all users of configure.py to use --config Bazel 378 # options and not to set build configs through environment variables. 379 write_to_bazelrc('build:%s --define %s=true' 380 % (bazel_config_name, option_name)) 381 382 383def set_action_env_var(environ_cp, 384 var_name, 385 query_item, 386 enabled_by_default, 387 question=None, 388 yes_reply=None, 389 no_reply=None): 390 """Set boolean action_env variable. 391 392 Ask user if query_item will be enabled. Default is used if no input is given. 393 Set environment variable and write to .bazelrc. 394 395 Args: 396 environ_cp: copy of the os.environ. 397 var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". 398 query_item: string for feature related to the variable, e.g. "Hadoop File 399 System". 400 enabled_by_default: boolean for default behavior. 401 question: optional string for how to ask for user input. 402 yes_reply: optional string for reply when feature is enabled. 403 no_reply: optional string for reply when feature is disabled. 404 """ 405 var = int( 406 get_var(environ_cp, var_name, query_item, enabled_by_default, question, 407 yes_reply, no_reply)) 408 409 write_action_env_to_bazelrc(var_name, var) 410 environ_cp[var_name] = str(var) 411 412 413def convert_version_to_int(version): 414 """Convert a version number to a integer that can be used to compare. 415 416 Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The 417 'xxxxx' part, for instance 'homebrew' on OS/X, is ignored. 418 419 Args: 420 version: a version to be converted 421 422 Returns: 423 An integer if converted successfully, otherwise return None. 424 """ 425 version = version.split('-')[0] 426 version_segments = version.split('.') 427 for seg in version_segments: 428 if not seg.isdigit(): 429 return None 430 431 version_str = ''.join(['%03d' % int(seg) for seg in version_segments]) 432 return int(version_str) 433 434 435def check_bazel_version(min_version): 436 """Check installed bazel version is at least min_version. 437 438 Args: 439 min_version: string for minimum bazel version. 440 441 Returns: 442 The bazel version detected. 443 """ 444 if which('bazel') is None: 445 print('Cannot find bazel. Please install bazel.') 446 sys.exit(0) 447 curr_version = run_shell(['bazel', '--batch', 'version']) 448 449 for line in curr_version.split('\n'): 450 if 'Build label: ' in line: 451 curr_version = line.split('Build label: ')[1] 452 break 453 454 min_version_int = convert_version_to_int(min_version) 455 curr_version_int = convert_version_to_int(curr_version) 456 457 # Check if current bazel version can be detected properly. 458 if not curr_version_int: 459 print('WARNING: current bazel installation is not a release version.') 460 print('Make sure you are running at least bazel %s' % min_version) 461 return curr_version 462 463 print('You have bazel %s installed.' % curr_version) 464 465 if curr_version_int < min_version_int: 466 print('Please upgrade your bazel installation to version %s or higher to ' 467 'build TensorFlow!' % min_version) 468 sys.exit(0) 469 return curr_version 470 471 472def set_cc_opt_flags(environ_cp): 473 """Set up architecture-dependent optimization flags. 474 475 Also append CC optimization flags to bazel.rc.. 476 477 Args: 478 environ_cp: copy of the os.environ. 479 """ 480 if is_ppc64le(): 481 # gcc on ppc64le does not support -march, use mcpu instead 482 default_cc_opt_flags = '-mcpu=native' 483 else: 484 default_cc_opt_flags = '-march=native' 485 question = ('Please specify optimization flags to use during compilation when' 486 ' bazel option "--config=opt" is specified [Default is %s]: ' 487 ) % default_cc_opt_flags 488 cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', 489 question, default_cc_opt_flags) 490 for opt in cc_opt_flags.split(): 491 write_to_bazelrc('build:opt --copt=%s' % opt) 492 # It should be safe on the same build host. 493 if not is_ppc64le(): 494 write_to_bazelrc('build:opt --host_copt=-march=native') 495 write_to_bazelrc('build:opt --define with_default_optimizations=true') 496 # TODO(mikecase): Remove these default defines once we are able to get 497 # TF Lite targets building without them. 498 write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') 499 write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') 500 501 502def set_tf_cuda_clang(environ_cp): 503 """set TF_CUDA_CLANG action_env. 504 505 Args: 506 environ_cp: copy of the os.environ. 507 """ 508 question = 'Do you want to use clang as CUDA compiler?' 509 yes_reply = 'Clang will be used as CUDA compiler.' 510 no_reply = 'nvcc will be used as CUDA compiler.' 511 set_action_env_var( 512 environ_cp, 513 'TF_CUDA_CLANG', 514 None, 515 False, 516 question=question, 517 yes_reply=yes_reply, 518 no_reply=no_reply) 519 520 521def set_tf_download_clang(environ_cp): 522 """Set TF_DOWNLOAD_CLANG action_env.""" 523 question = 'Do you want to download a fresh release of clang? (Experimental)' 524 yes_reply = 'Clang will be downloaded and used to compile tensorflow.' 525 no_reply = 'Clang will not be downloaded.' 526 set_action_env_var( 527 environ_cp, 528 'TF_DOWNLOAD_CLANG', 529 None, 530 False, 531 question=question, 532 yes_reply=yes_reply, 533 no_reply=no_reply) 534 535 536def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, 537 var_default): 538 """Get var_name either from env, or user or default. 539 540 If var_name has been set as environment variable, use the preset value, else 541 ask for user input. If no input is provided, the default is used. 542 543 Args: 544 environ_cp: copy of the os.environ. 545 var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". 546 ask_for_var: string for how to ask for user input. 547 var_default: default value string. 548 549 Returns: 550 string value for var_name 551 """ 552 var = environ_cp.get(var_name) 553 if not var: 554 var = get_input(ask_for_var) 555 print('\n') 556 if not var: 557 var = var_default 558 return var 559 560 561def set_clang_cuda_compiler_path(environ_cp): 562 """Set CLANG_CUDA_COMPILER_PATH.""" 563 default_clang_path = which('clang') or '' 564 ask_clang_path = ('Please specify which clang should be used as device and ' 565 'host compiler. [Default is %s]: ') % default_clang_path 566 567 while True: 568 clang_cuda_compiler_path = get_from_env_or_user_or_default( 569 environ_cp, 'CLANG_CUDA_COMPILER_PATH', ask_clang_path, 570 default_clang_path) 571 if os.path.exists(clang_cuda_compiler_path): 572 break 573 574 # Reset and retry 575 print('Invalid clang path: %s cannot be found.' % clang_cuda_compiler_path) 576 environ_cp['CLANG_CUDA_COMPILER_PATH'] = '' 577 578 # Set CLANG_CUDA_COMPILER_PATH 579 environ_cp['CLANG_CUDA_COMPILER_PATH'] = clang_cuda_compiler_path 580 write_action_env_to_bazelrc('CLANG_CUDA_COMPILER_PATH', 581 clang_cuda_compiler_path) 582 583 584def prompt_loop_or_load_from_env( 585 environ_cp, 586 var_name, 587 var_default, 588 ask_for_var, 589 check_success, 590 error_msg, 591 suppress_default_error=False, 592 n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS 593): 594 """Loop over user prompts for an ENV param until receiving a valid response. 595 596 For the env param var_name, read from the environment or verify user input 597 until receiving valid input. When done, set var_name in the environ_cp to its 598 new value. 599 600 Args: 601 environ_cp: (Dict) copy of the os.environ. 602 var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". 603 var_default: (String) default value string. 604 ask_for_var: (String) string for how to ask for user input. 605 check_success: (Function) function that takes one argument and returns a 606 boolean. Should return True if the value provided is considered valid. May 607 contain a complex error message if error_msg does not provide enough 608 information. In that case, set suppress_default_error to True. 609 error_msg: (String) String with one and only one '%s'. Formatted with each 610 invalid response upon check_success(input) failure. 611 suppress_default_error: (Bool) Suppress the above error message in favor of 612 one from the check_success function. 613 n_ask_attempts: (Integer) Number of times to query for valid input before 614 raising an error and quitting. 615 616 Returns: 617 [String] The value of var_name after querying for input. 618 619 Raises: 620 UserInputError: if a query has been attempted n_ask_attempts times without 621 success, assume that the user has made a scripting error, and will 622 continue to provide invalid input. Raise the error to avoid infinitely 623 looping. 624 """ 625 default = environ_cp.get(var_name) or var_default 626 full_query = '%s [Default is %s]: ' % ( 627 ask_for_var, 628 default, 629 ) 630 631 for _ in range(n_ask_attempts): 632 val = get_from_env_or_user_or_default(environ_cp, 633 var_name, 634 full_query, 635 default) 636 if check_success(val): 637 break 638 if not suppress_default_error: 639 print(error_msg % val) 640 environ_cp[var_name] = '' 641 else: 642 raise UserInputError('Invalid %s setting was provided %d times in a row. ' 643 'Assuming to be a scripting mistake.' % 644 (var_name, n_ask_attempts)) 645 646 environ_cp[var_name] = val 647 return val 648 649 650def create_android_ndk_rule(environ_cp): 651 """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" 652 if is_windows() or is_cygwin(): 653 default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % 654 environ_cp['APPDATA']) 655 elif is_macos(): 656 default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] 657 else: 658 default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME'] 659 660 def valid_ndk_path(path): 661 return (os.path.exists(path) and 662 os.path.exists(os.path.join(path, 'source.properties'))) 663 664 android_ndk_home_path = prompt_loop_or_load_from_env( 665 environ_cp, 666 var_name='ANDROID_NDK_HOME', 667 var_default=default_ndk_path, 668 ask_for_var='Please specify the home path of the Android NDK to use.', 669 check_success=valid_ndk_path, 670 error_msg=('The path %s or its child file "source.properties" ' 671 'does not exist.') 672 ) 673 674 write_android_ndk_workspace_rule(android_ndk_home_path) 675 676 677def create_android_sdk_rule(environ_cp): 678 """Set Android variables and write Android SDK WORKSPACE rule.""" 679 if is_windows() or is_cygwin(): 680 default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) 681 elif is_macos(): 682 default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] 683 else: 684 default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] 685 686 def valid_sdk_path(path): 687 return (os.path.exists(path) and 688 os.path.exists(os.path.join(path, 'platforms')) and 689 os.path.exists(os.path.join(path, 'build-tools'))) 690 691 android_sdk_home_path = prompt_loop_or_load_from_env( 692 environ_cp, 693 var_name='ANDROID_SDK_HOME', 694 var_default=default_sdk_path, 695 ask_for_var='Please specify the home path of the Android SDK to use.', 696 check_success=valid_sdk_path, 697 error_msg=('Either %s does not exist, or it does not contain the ' 698 'subdirectories "platforms" and "build-tools".')) 699 700 platforms = os.path.join(android_sdk_home_path, 'platforms') 701 api_levels = sorted(os.listdir(platforms)) 702 api_levels = [x.replace('android-', '') for x in api_levels] 703 704 def valid_api_level(api_level): 705 return os.path.exists(os.path.join(android_sdk_home_path, 706 'platforms', 707 'android-' + api_level)) 708 709 android_api_level = prompt_loop_or_load_from_env( 710 environ_cp, 711 var_name='ANDROID_API_LEVEL', 712 var_default=api_levels[-1], 713 ask_for_var=('Please specify the Android SDK API level to use. ' 714 '[Available levels: %s]') % api_levels, 715 check_success=valid_api_level, 716 error_msg='Android-%s is not present in the SDK path.') 717 718 build_tools = os.path.join(android_sdk_home_path, 'build-tools') 719 versions = sorted(os.listdir(build_tools)) 720 721 def valid_build_tools(version): 722 return os.path.exists(os.path.join(android_sdk_home_path, 723 'build-tools', 724 version)) 725 726 android_build_tools_version = prompt_loop_or_load_from_env( 727 environ_cp, 728 var_name='ANDROID_BUILD_TOOLS_VERSION', 729 var_default=versions[-1], 730 ask_for_var=('Please specify an Android build tools version to use. ' 731 '[Available versions: %s]') % versions, 732 check_success=valid_build_tools, 733 error_msg=('The selected SDK does not have build-tools version %s ' 734 'available.')) 735 736 write_android_sdk_workspace_rule(android_sdk_home_path, 737 android_build_tools_version, 738 android_api_level) 739 740 741def write_android_sdk_workspace_rule(android_sdk_home_path, 742 android_build_tools_version, 743 android_api_level): 744 print('Writing android_sdk_workspace rule.\n') 745 with open(_TF_WORKSPACE, 'a') as f: 746 f.write(""" 747android_sdk_repository( 748 name="androidsdk", 749 api_level=%s, 750 path="%s", 751 build_tools_version="%s")\n 752""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) 753 754 755def write_android_ndk_workspace_rule(android_ndk_home_path): 756 print('Writing android_ndk_workspace rule.') 757 ndk_api_level = check_ndk_level(android_ndk_home_path) 758 if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: 759 print('WARNING: The API level of the NDK in %s is %s, which is not ' 760 'supported by Bazel (officially supported versions: %s). Please use ' 761 'another version. Compiling Android targets may result in confusing ' 762 'errors.\n' % (android_ndk_home_path, ndk_api_level, 763 _SUPPORTED_ANDROID_NDK_VERSIONS)) 764 with open(_TF_WORKSPACE, 'a') as f: 765 f.write(""" 766android_ndk_repository( 767 name="androidndk", 768 path="%s", 769 api_level=%s)\n 770""" % (android_ndk_home_path, ndk_api_level)) 771 772 773def check_ndk_level(android_ndk_home_path): 774 """Check the revision number of an Android NDK path.""" 775 properties_path = '%s/source.properties' % android_ndk_home_path 776 if is_windows() or is_cygwin(): 777 properties_path = cygpath(properties_path) 778 with open(properties_path, 'r') as f: 779 filedata = f.read() 780 781 revision = re.search(r'Pkg.Revision = (\d+)', filedata) 782 if revision: 783 return revision.group(1) 784 return None 785 786 787def workspace_has_any_android_rule(): 788 """Check the WORKSPACE for existing android_*_repository rules.""" 789 with open(_TF_WORKSPACE, 'r') as f: 790 workspace = f.read() 791 has_any_rule = re.search(r'^android_[ns]dk_repository', 792 workspace, 793 re.MULTILINE) 794 return has_any_rule 795 796 797def set_gcc_host_compiler_path(environ_cp): 798 """Set GCC_HOST_COMPILER_PATH.""" 799 default_gcc_host_compiler_path = which('gcc') or '' 800 cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH') 801 802 if os.path.islink(cuda_bin_symlink): 803 # os.readlink is only available in linux 804 default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) 805 806 gcc_host_compiler_path = prompt_loop_or_load_from_env( 807 environ_cp, 808 var_name='GCC_HOST_COMPILER_PATH', 809 var_default=default_gcc_host_compiler_path, 810 ask_for_var= 811 'Please specify which gcc should be used by nvcc as the host compiler.', 812 check_success=os.path.exists, 813 error_msg='Invalid gcc path. %s cannot be found.', 814 ) 815 816 write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) 817 818 819def reformat_version_sequence(version_str, sequence_count): 820 """Reformat the version string to have the given number of sequences. 821 822 For example: 823 Given (7, 2) -> 7.0 824 (7.0.1, 2) -> 7.0 825 (5, 1) -> 5 826 (5.0.3.2, 1) -> 5 827 828 Args: 829 version_str: String, the version string. 830 sequence_count: int, an integer. 831 Returns: 832 string, reformatted version string. 833 """ 834 v = version_str.split('.') 835 if len(v) < sequence_count: 836 v = v + (['0'] * (sequence_count - len(v))) 837 838 return '.'.join(v[:sequence_count]) 839 840 841def set_tf_cuda_version(environ_cp): 842 """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" 843 ask_cuda_version = ( 844 'Please specify the CUDA SDK version you want to use, ' 845 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION 846 847 for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): 848 # Configure the Cuda SDK version to use. 849 tf_cuda_version = get_from_env_or_user_or_default( 850 environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) 851 tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2) 852 853 # Find out where the CUDA toolkit is installed 854 default_cuda_path = _DEFAULT_CUDA_PATH 855 if is_windows() or is_cygwin(): 856 default_cuda_path = cygpath( 857 environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN)) 858 elif is_linux(): 859 # If the default doesn't exist, try an alternative default. 860 if (not os.path.exists(default_cuda_path) 861 ) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX): 862 default_cuda_path = _DEFAULT_CUDA_PATH_LINUX 863 ask_cuda_path = ('Please specify the location where CUDA %s toolkit is' 864 ' installed. Refer to README.md for more details. ' 865 '[Default is %s]: ') % (tf_cuda_version, default_cuda_path) 866 cuda_toolkit_path = get_from_env_or_user_or_default( 867 environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path) 868 869 if is_windows(): 870 cuda_rt_lib_path = 'lib/x64/cudart.lib' 871 elif is_linux(): 872 cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version 873 elif is_macos(): 874 cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version 875 876 cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path) 877 if os.path.exists(cuda_toolkit_path_full): 878 break 879 880 # Reset and retry 881 print('Invalid path to CUDA %s toolkit. %s cannot be found' % 882 (tf_cuda_version, cuda_toolkit_path_full)) 883 environ_cp['TF_CUDA_VERSION'] = '' 884 environ_cp['CUDA_TOOLKIT_PATH'] = '' 885 886 else: 887 raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d ' 888 'times in a row. Assuming to be a scripting mistake.' % 889 _DEFAULT_PROMPT_ASK_ATTEMPTS) 890 891 # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION 892 environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path 893 write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path) 894 environ_cp['TF_CUDA_VERSION'] = tf_cuda_version 895 write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version) 896 897 898def set_tf_cudnn_version(environ_cp): 899 """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" 900 ask_cudnn_version = ( 901 'Please specify the cuDNN version you want to use. ' 902 '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION 903 904 for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): 905 tf_cudnn_version = get_from_env_or_user_or_default( 906 environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, 907 _DEFAULT_CUDNN_VERSION) 908 tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1) 909 910 default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') 911 ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' 912 'installed. Refer to README.md for more details. [Default' 913 ' is %s]:') % (tf_cudnn_version, default_cudnn_path) 914 cudnn_install_path = get_from_env_or_user_or_default( 915 environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path) 916 917 # Result returned from "read" will be used unexpanded. That make "~" 918 # unusable. Going through one more level of expansion to handle that. 919 cudnn_install_path = os.path.realpath( 920 os.path.expanduser(cudnn_install_path)) 921 if is_windows() or is_cygwin(): 922 cudnn_install_path = cygpath(cudnn_install_path) 923 924 if is_windows(): 925 cuda_dnn_lib_path = 'lib/x64/cudnn.lib' 926 cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib' 927 elif is_linux(): 928 cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version 929 cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version 930 elif is_macos(): 931 cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version 932 cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version 933 934 cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path) 935 cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path, 936 cuda_dnn_lib_alt_path) 937 if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists( 938 cuda_dnn_lib_alt_path_full): 939 break 940 941 # Try another alternative for Linux 942 if is_linux(): 943 ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' 944 cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) 945 cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)', 946 cudnn_path_from_ldconfig) 947 if cudnn_path_from_ldconfig: 948 cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1) 949 if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, 950 tf_cudnn_version)): 951 cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig) 952 break 953 954 # Reset and Retry 955 print( 956 'Invalid path to cuDNN %s toolkit. None of the following files can be ' 957 'found:' % tf_cudnn_version) 958 print(cuda_dnn_lib_path_full) 959 print(cuda_dnn_lib_alt_path_full) 960 if is_linux(): 961 print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)) 962 963 environ_cp['TF_CUDNN_VERSION'] = '' 964 else: 965 raise UserInputError('Invalid TF_CUDNN setting was provided %d ' 966 'times in a row. Assuming to be a scripting mistake.' % 967 _DEFAULT_PROMPT_ASK_ATTEMPTS) 968 969 # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION 970 environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path 971 write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path) 972 environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version 973 write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) 974 975 976def set_tf_tensorrt_install_path(environ_cp): 977 """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. 978 979 Adapted from code contributed by Sami Kama (https://github.com/samikama). 980 981 Args: 982 environ_cp: copy of the os.environ. 983 984 Raises: 985 ValueError: if this method was called under non-Linux platform. 986 UserInputError: if user has provided invalid input multiple times. 987 """ 988 if not is_linux(): 989 raise ValueError('Currently TensorRT is only supported on Linux platform.') 990 991 # Ask user whether to add TensorRT support. 992 if str(int(get_var( 993 environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': 994 return 995 996 for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): 997 ask_tensorrt_path = (r'Please specify the location where TensorRT is ' 998 'installed. [Default is %s]:') % ( 999 _DEFAULT_TENSORRT_PATH_LINUX) 1000 trt_install_path = get_from_env_or_user_or_default( 1001 environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path, 1002 _DEFAULT_TENSORRT_PATH_LINUX) 1003 1004 # Result returned from "read" will be used unexpanded. That make "~" 1005 # unusable. Going through one more level of expansion to handle that. 1006 trt_install_path = os.path.realpath( 1007 os.path.expanduser(trt_install_path)) 1008 1009 def find_libs(search_path): 1010 """Search for libnvinfer.so in "search_path".""" 1011 fl = set() 1012 if os.path.exists(search_path) and os.path.isdir(search_path): 1013 fl.update([os.path.realpath(os.path.join(search_path, x)) 1014 for x in os.listdir(search_path) if 'libnvinfer.so' in x]) 1015 return fl 1016 1017 possible_files = find_libs(trt_install_path) 1018 possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) 1019 possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) 1020 1021 def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): 1022 """Check the compatibility between tensorrt and cudnn/cudart libraries.""" 1023 ldd_bin = which('ldd') or '/usr/bin/ldd' 1024 ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) 1025 cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') 1026 cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') 1027 cudnn = None 1028 cudart = None 1029 for line in ldd_out: 1030 if 'libcudnn.so' in line: 1031 cudnn = cudnn_pattern.search(line) 1032 elif 'libcudart.so' in line: 1033 cudart = cuda_pattern.search(line) 1034 if cudnn and len(cudnn.group(1)): 1035 cudnn = convert_version_to_int(cudnn.group(1)) 1036 if cudart and len(cudart.group(1)): 1037 cudart = convert_version_to_int(cudart.group(1)) 1038 return (cudnn == cudnn_ver) and (cudart == cuda_ver) 1039 1040 cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) 1041 cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) 1042 nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') 1043 highest_ver = [0, None, None] 1044 1045 for lib_file in possible_files: 1046 if is_compatible(lib_file, cuda_ver, cudnn_ver): 1047 ver_str = nvinfer_pattern.search(lib_file).group(1) 1048 ver = convert_version_to_int(ver_str) if len(ver_str) else 0 1049 if ver > highest_ver[0]: 1050 highest_ver = [ver, ver_str, lib_file] 1051 if highest_ver[1] is not None: 1052 trt_install_path = os.path.dirname(highest_ver[2]) 1053 tf_tensorrt_version = highest_ver[1] 1054 break 1055 1056 # Try another alternative from ldconfig. 1057 ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' 1058 ldconfig_output = run_shell([ldconfig_bin, '-p']) 1059 search_result = re.search( 1060 '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) 1061 if search_result: 1062 libnvinfer_path_from_ldconfig = search_result.group(2) 1063 if os.path.exists(libnvinfer_path_from_ldconfig): 1064 if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): 1065 trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) 1066 tf_tensorrt_version = search_result.group(1) 1067 break 1068 1069 # Reset and Retry 1070 if len(possible_files): 1071 print('TensorRT libraries found in one the following directories', 1072 'are not compatible with selected cuda and cudnn installations') 1073 print(trt_install_path) 1074 print(os.path.join(trt_install_path, 'lib')) 1075 print(os.path.join(trt_install_path, 'lib64')) 1076 if search_result: 1077 print(libnvinfer_path_from_ldconfig) 1078 else: 1079 print('Invalid path to TensorRT. None of the following files can be found:') 1080 print(trt_install_path) 1081 print(os.path.join(trt_install_path, 'lib')) 1082 print(os.path.join(trt_install_path, 'lib64')) 1083 if search_result: 1084 print(libnvinfer_path_from_ldconfig) 1085 1086 else: 1087 raise UserInputError('Invalid TF_TENSORRT setting was provided %d ' 1088 'times in a row. Assuming to be a scripting mistake.' % 1089 _DEFAULT_PROMPT_ASK_ATTEMPTS) 1090 1091 # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION 1092 environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path 1093 write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path) 1094 environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version 1095 write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) 1096 1097 1098def get_native_cuda_compute_capabilities(environ_cp): 1099 """Get native cuda compute capabilities. 1100 1101 Args: 1102 environ_cp: copy of the os.environ. 1103 Returns: 1104 string of native cuda compute capabilities, separated by comma. 1105 """ 1106 device_query_bin = os.path.join( 1107 environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery') 1108 if os.path.isfile(device_query_bin) and os.access(device_query_bin, os.X_OK): 1109 try: 1110 output = run_shell(device_query_bin).split('\n') 1111 pattern = re.compile('[0-9]*\\.[0-9]*') 1112 output = [pattern.search(x) for x in output if 'Capability' in x] 1113 output = ','.join(x.group() for x in output if x is not None) 1114 except subprocess.CalledProcessError: 1115 output = '' 1116 else: 1117 output = '' 1118 return output 1119 1120 1121def set_tf_cuda_compute_capabilities(environ_cp): 1122 """Set TF_CUDA_COMPUTE_CAPABILITIES.""" 1123 while True: 1124 native_cuda_compute_capabilities = get_native_cuda_compute_capabilities( 1125 environ_cp) 1126 if not native_cuda_compute_capabilities: 1127 default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES 1128 else: 1129 default_cuda_compute_capabilities = native_cuda_compute_capabilities 1130 1131 ask_cuda_compute_capabilities = ( 1132 'Please specify a list of comma-separated ' 1133 'Cuda compute capabilities you want to ' 1134 'build with.\nYou can find the compute ' 1135 'capability of your device at: ' 1136 'https://developer.nvidia.com/cuda-gpus.\nPlease' 1137 ' note that each additional compute ' 1138 'capability significantly increases your ' 1139 'build time and binary size. [Default is: %s]' % 1140 default_cuda_compute_capabilities) 1141 tf_cuda_compute_capabilities = get_from_env_or_user_or_default( 1142 environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', 1143 ask_cuda_compute_capabilities, default_cuda_compute_capabilities) 1144 # Check whether all capabilities from the input is valid 1145 all_valid = True 1146 for compute_capability in tf_cuda_compute_capabilities.split(','): 1147 m = re.match('[0-9]+.[0-9]+', compute_capability) 1148 if not m: 1149 print('Invalid compute capability: ' % compute_capability) 1150 all_valid = False 1151 else: 1152 ver = int(m.group(0).split('.')[0]) 1153 if ver < 3: 1154 print('Only compute capabilities 3.0 or higher are supported.') 1155 all_valid = False 1156 1157 if all_valid: 1158 break 1159 1160 # Reset and Retry 1161 environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = '' 1162 1163 # Set TF_CUDA_COMPUTE_CAPABILITIES 1164 environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities 1165 write_action_env_to_bazelrc('TF_CUDA_COMPUTE_CAPABILITIES', 1166 tf_cuda_compute_capabilities) 1167 1168 1169def set_other_cuda_vars(environ_cp): 1170 """Set other CUDA related variables.""" 1171 if is_windows(): 1172 # The following three variables are needed for MSVC toolchain configuration 1173 # in Bazel 1174 environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH') 1175 environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get( 1176 'TF_CUDA_COMPUTE_CAPABILITIES') 1177 environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1 1178 write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH')) 1179 write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE', 1180 environ_cp.get('CUDA_COMPUTE_CAPABILITIE')) 1181 write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION', 1182 environ_cp.get('NO_WHOLE_ARCHIVE_OPTION')) 1183 write_to_bazelrc('build --config=win-cuda') 1184 write_to_bazelrc('test --config=win-cuda') 1185 else: 1186 # If CUDA is enabled, always use GPU during build and test. 1187 if environ_cp.get('TF_CUDA_CLANG') == '1': 1188 write_to_bazelrc('build --config=cuda_clang') 1189 write_to_bazelrc('test --config=cuda_clang') 1190 else: 1191 write_to_bazelrc('build --config=cuda') 1192 write_to_bazelrc('test --config=cuda') 1193 1194 1195def set_host_cxx_compiler(environ_cp): 1196 """Set HOST_CXX_COMPILER.""" 1197 default_cxx_host_compiler = which('g++') or '' 1198 1199 host_cxx_compiler = prompt_loop_or_load_from_env( 1200 environ_cp, 1201 var_name='HOST_CXX_COMPILER', 1202 var_default=default_cxx_host_compiler, 1203 ask_for_var=('Please specify which C++ compiler should be used as the ' 1204 'host C++ compiler.'), 1205 check_success=os.path.exists, 1206 error_msg='Invalid C++ compiler path. %s cannot be found.', 1207 ) 1208 1209 write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler) 1210 1211 1212def set_host_c_compiler(environ_cp): 1213 """Set HOST_C_COMPILER.""" 1214 default_c_host_compiler = which('gcc') or '' 1215 1216 host_c_compiler = prompt_loop_or_load_from_env( 1217 environ_cp, 1218 var_name='HOST_C_COMPILER', 1219 var_default=default_c_host_compiler, 1220 ask_for_var=('Please specify which C compiler should be used as the host ' 1221 'C compiler.'), 1222 check_success=os.path.exists, 1223 error_msg='Invalid C compiler path. %s cannot be found.', 1224 ) 1225 1226 write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) 1227 1228 1229def set_computecpp_toolkit_path(environ_cp): 1230 """Set COMPUTECPP_TOOLKIT_PATH.""" 1231 1232 def toolkit_exists(toolkit_path): 1233 """Check if a computecpp toolkit path is valid.""" 1234 if is_linux(): 1235 sycl_rt_lib_path = 'lib/libComputeCpp.so' 1236 else: 1237 sycl_rt_lib_path = '' 1238 1239 sycl_rt_lib_path_full = os.path.join(toolkit_path, 1240 sycl_rt_lib_path) 1241 exists = os.path.exists(sycl_rt_lib_path_full) 1242 if not exists: 1243 print('Invalid SYCL %s library path. %s cannot be found' % 1244 (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) 1245 return exists 1246 1247 computecpp_toolkit_path = prompt_loop_or_load_from_env( 1248 environ_cp, 1249 var_name='COMPUTECPP_TOOLKIT_PATH', 1250 var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH, 1251 ask_for_var=( 1252 'Please specify the location where ComputeCpp for SYCL %s is ' 1253 'installed.' % _TF_OPENCL_VERSION), 1254 check_success=toolkit_exists, 1255 error_msg='Invalid SYCL compiler path. %s cannot be found.', 1256 suppress_default_error=True) 1257 1258 write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', 1259 computecpp_toolkit_path) 1260 1261 1262def set_trisycl_include_dir(environ_cp): 1263 """Set TRISYCL_INCLUDE_DIR.""" 1264 1265 ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' 1266 'include directory. (Use --config=sycl_trisycl ' 1267 'when building with Bazel) ' 1268 '[Default is %s]: ' 1269 ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) 1270 1271 while True: 1272 trisycl_include_dir = get_from_env_or_user_or_default( 1273 environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, 1274 _DEFAULT_TRISYCL_INCLUDE_DIR) 1275 if os.path.exists(trisycl_include_dir): 1276 break 1277 1278 print('Invalid triSYCL include directory, %s cannot be found' 1279 % (trisycl_include_dir)) 1280 1281 # Set TRISYCL_INCLUDE_DIR 1282 environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir 1283 write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', 1284 trisycl_include_dir) 1285 1286 1287def set_mpi_home(environ_cp): 1288 """Set MPI_HOME.""" 1289 1290 default_mpi_home = which('mpirun') or which('mpiexec') or '' 1291 default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) 1292 1293 def valid_mpi_path(mpi_home): 1294 exists = (os.path.exists(os.path.join(mpi_home, 'include')) and 1295 os.path.exists(os.path.join(mpi_home, 'lib'))) 1296 if not exists: 1297 print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % 1298 (os.path.join(mpi_home, 'include'), 1299 os.path.exists(os.path.join(mpi_home, 'lib')))) 1300 return exists 1301 1302 _ = prompt_loop_or_load_from_env( 1303 environ_cp, 1304 var_name='MPI_HOME', 1305 var_default=default_mpi_home, 1306 ask_for_var='Please specify the MPI toolkit folder.', 1307 check_success=valid_mpi_path, 1308 error_msg='', 1309 suppress_default_error=True) 1310 1311 1312def set_other_mpi_vars(environ_cp): 1313 """Set other MPI related variables.""" 1314 # Link the MPI header files 1315 mpi_home = environ_cp.get('MPI_HOME') 1316 symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h') 1317 1318 # Determine if we use OpenMPI or MVAPICH, these require different header files 1319 # to be included here to make bazel dependency checker happy 1320 if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')): 1321 symlink_force( 1322 os.path.join(mpi_home, 'include/mpi_portable_platform.h'), 1323 'third_party/mpi/mpi_portable_platform.h') 1324 # TODO(gunan): avoid editing files in configure 1325 sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=False', 1326 'MPI_LIB_IS_OPENMPI=True') 1327 else: 1328 # MVAPICH / MPICH 1329 symlink_force( 1330 os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h') 1331 symlink_force( 1332 os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h') 1333 # TODO(gunan): avoid editing files in configure 1334 sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=True', 1335 'MPI_LIB_IS_OPENMPI=False') 1336 1337 if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')): 1338 symlink_force( 1339 os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so') 1340 else: 1341 raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) 1342 1343 1344def set_grpc_build_flags(): 1345 write_to_bazelrc('build --define grpc_no_ares=true') 1346 1347 1348def set_windows_build_flags(): 1349 if is_windows(): 1350 # The non-monolithic build is not supported yet 1351 write_to_bazelrc('build --config monolithic') 1352 # Suppress warning messages 1353 write_to_bazelrc('build --copt=-w --host_copt=-w') 1354 # Output more verbose information when something goes wrong 1355 write_to_bazelrc('build --verbose_failures') 1356 1357 1358def config_info_line(name, help_text): 1359 """Helper function to print formatted help text for Bazel config options.""" 1360 print('\t--config=%-12s\t# %s' % (name, help_text)) 1361 1362 1363def main(): 1364 parser = argparse.ArgumentParser() 1365 parser.add_argument("--workspace", 1366 type=str, 1367 default=_TF_WORKSPACE_ROOT, 1368 help="The absolute path to your active Bazel workspace.") 1369 args = parser.parse_args() 1370 1371 # Make a copy of os.environ to be clear when functions and getting and setting 1372 # environment variables. 1373 environ_cp = dict(os.environ) 1374 1375 check_bazel_version('0.5.4') 1376 1377 reset_tf_configure_bazelrc(args.workspace) 1378 cleanup_makefile() 1379 setup_python(environ_cp) 1380 1381 if is_windows(): 1382 environ_cp['TF_NEED_S3'] = '0' 1383 environ_cp['TF_NEED_GCP'] = '0' 1384 environ_cp['TF_NEED_HDFS'] = '0' 1385 environ_cp['TF_NEED_JEMALLOC'] = '0' 1386 environ_cp['TF_NEED_KAFKA'] = '0' 1387 environ_cp['TF_NEED_OPENCL_SYCL'] = '0' 1388 environ_cp['TF_NEED_COMPUTECPP'] = '0' 1389 environ_cp['TF_NEED_OPENCL'] = '0' 1390 environ_cp['TF_CUDA_CLANG'] = '0' 1391 environ_cp['TF_NEED_TENSORRT'] = '0' 1392 1393 if is_macos(): 1394 environ_cp['TF_NEED_JEMALLOC'] = '0' 1395 environ_cp['TF_NEED_TENSORRT'] = '0' 1396 1397 set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 1398 'with_jemalloc', True) 1399 set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', 1400 'with_gcp_support', True, 'gcp') 1401 set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', 1402 'with_hdfs_support', True, 'hdfs') 1403 set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 1404 'with_s3_support', True, 's3') 1405 set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', 1406 'with_kafka_support', False, 'kafka') 1407 set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', 1408 False, 'xla') 1409 set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', 1410 False, 'gdr') 1411 set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', 1412 False, 'verbs') 1413 1414 set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) 1415 if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': 1416 set_host_cxx_compiler(environ_cp) 1417 set_host_c_compiler(environ_cp) 1418 set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True) 1419 if environ_cp.get('TF_NEED_COMPUTECPP') == '1': 1420 set_computecpp_toolkit_path(environ_cp) 1421 else: 1422 set_trisycl_include_dir(environ_cp) 1423 1424 set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False) 1425 if (environ_cp.get('TF_NEED_CUDA') == '1' and 1426 'TF_CUDA_CONFIG_REPO' not in environ_cp): 1427 set_tf_cuda_version(environ_cp) 1428 set_tf_cudnn_version(environ_cp) 1429 if is_linux(): 1430 set_tf_tensorrt_install_path(environ_cp) 1431 set_tf_cuda_compute_capabilities(environ_cp) 1432 if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( 1433 'LD_LIBRARY_PATH') != '1': 1434 write_action_env_to_bazelrc('LD_LIBRARY_PATH', 1435 environ_cp.get('LD_LIBRARY_PATH')) 1436 1437 set_tf_cuda_clang(environ_cp) 1438 if environ_cp.get('TF_CUDA_CLANG') == '1': 1439 if not is_windows(): 1440 # Ask if we want to download clang release while building. 1441 set_tf_download_clang(environ_cp) 1442 else: 1443 # We use bazel's generated crosstool on Windows and there is no 1444 # way to provide downloaded toolchain for that yet. 1445 # TODO(ibiryukov): Investigate using clang as a cuda compiler on 1446 # Windows. 1447 environ_cp['TF_DOWNLOAD_CLANG'] = '0' 1448 1449 if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': 1450 # Set up which clang we should use as the cuda / host compiler. 1451 set_clang_cuda_compiler_path(environ_cp) 1452 else: 1453 # Set up which gcc nvcc should use as the host compiler 1454 # No need to set this on Windows 1455 if not is_windows(): 1456 set_gcc_host_compiler_path(environ_cp) 1457 set_other_cuda_vars(environ_cp) 1458 1459 set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) 1460 if environ_cp.get('TF_NEED_MPI') == '1': 1461 set_mpi_home(environ_cp) 1462 set_other_mpi_vars(environ_cp) 1463 1464 set_grpc_build_flags() 1465 set_cc_opt_flags(environ_cp) 1466 set_windows_build_flags() 1467 1468 if workspace_has_any_android_rule(): 1469 print('The WORKSPACE file has at least one of ["android_sdk_repository", ' 1470 '"android_ndk_repository"] already set. Will not ask to help ' 1471 'configure the WORKSPACE. Please delete the existing rules to ' 1472 'activate the helper.\n') 1473 else: 1474 if get_var( 1475 environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', 1476 False, 1477 ('Would you like to interactively configure ./WORKSPACE for ' 1478 'Android builds?'), 1479 'Searching for NDK and SDK installations.', 1480 'Not configuring the WORKSPACE for Android builds.'): 1481 create_android_ndk_rule(environ_cp) 1482 create_android_sdk_rule(environ_cp) 1483 1484 print('Preconfigured Bazel build configs. You can use any of the below by ' 1485 'adding "--config=<>" to your build command. See tools/bazel.rc for ' 1486 'more details.') 1487 config_info_line('mkl', 'Build with MKL support.') 1488 config_info_line('monolithic', 'Config for mostly static monolithic build.') 1489 1490if __name__ == '__main__': 1491 main() 1492