# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains function to log if devices are compatible with mixed precision."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from tensorflow.python.framework import config
from tensorflow.python.platform import tf_logging


_COMPAT_CHECK_PREFIX = 'Mixed precision compatibility check (mixed_float16): '
_COMPAT_CHECK_OK_PREFIX = _COMPAT_CHECK_PREFIX + 'OK'
_COMPAT_CHECK_WARNING_PREFIX = _COMPAT_CHECK_PREFIX + 'WARNING'
_COMPAT_CHECK_WARNING_SUFFIX = (
    'If you will use compatible GPU(s) not attached to this host, e.g. by '
    'running a multi-worker model, you can ignore this warning. This message '
    'will only be logged once')


def _dedup_strings(device_strs):
  """Groups together consecutive identical strings.

  For example, given:
      ['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3']
  This function returns:
      ['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)']

  Args:
    device_strs: A list of strings, each representing a device.

  Returns:
    A copy of the input, but identical consecutive strings are merged into a
    single string.
  """
  deduped = []
  # groupby without a sort is deliberate: only *consecutive* duplicates are
  # merged, preserving the caller's device ordering.
  for device_str, run in itertools.groupby(device_strs):
    run_length = sum(1 for _ in run)
    if run_length > 1:
      deduped.append('%s (x%d)' % (device_str, run_length))
    else:
      deduped.append(device_str)
  return deduped


def _log_device_compatibility_check(policy_name, gpu_details_list):
  """Logs a compatibility check if the devices support the policy.

  Currently only logs for the policy mixed_float16.

  Args:
    policy_name: The name of the dtype policy.
    gpu_details_list: A list of dicts, one dict per GPU. Each dict is the
      device details for a GPU, as returned by
      `tf.config.experimental.get_device_details()`.
  """
  if policy_name != 'mixed_float16':
    # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This requires
    # checking if a TPU is available.
    return
  supported = []
  unsupported = []
  for details in gpu_details_list:
    gpu_name = details.get('device_name', 'Unknown GPU')
    compute_capability = details.get('compute_capability')
    if not compute_capability:
      unsupported.append(
          gpu_name + ', no compute capability (probably not an Nvidia GPU)')
      continue
    device_str = '%s, compute capability %s.%s' % (
        gpu_name, compute_capability[0], compute_capability[1])
    # Tuple comparison: (major, minor) >= (7, 0) means Volta or newer, which
    # has Tensor Cores and so runs float16 math quickly.
    if compute_capability >= (7, 0):
      supported.append(device_str)
    else:
      unsupported.append(device_str)

  if unsupported:
    # At least one GPU cannot run float16 quickly: warn, choosing singular or
    # plural phrasing to match the device counts.
    warning_str = _COMPAT_CHECK_WARNING_PREFIX + '\n'
    if supported:
      warning_str += ('Some of your GPUs may run slowly with dtype policy '
                      'mixed_float16 because they do not all have compute '
                      'capability of at least 7.0. Your GPUs:\n')
    elif len(unsupported) == 1:
      warning_str += ('Your GPU may run slowly with dtype policy mixed_float16 '
                      'because it does not have compute capability of at least '
                      '7.0. Your GPU:\n')
    else:
      warning_str += ('Your GPUs may run slowly with dtype policy '
                      'mixed_float16 because they do not have compute '
                      'capability of at least 7.0. Your GPUs:\n')
    for device_str in _dedup_strings(supported + unsupported):
      warning_str += ' ' + device_str + '\n'
    warning_str += ('See https://developer.nvidia.com/cuda-gpus for a list of '
                    'GPUs and their compute capabilities.\n')
    warning_str += _COMPAT_CHECK_WARNING_SUFFIX
    tf_logging.warn(warning_str)
  elif not supported:
    # No GPUs at all: mixed_float16 will run, just slowly on CPU.
    tf_logging.warn('%s\n'
                    'The dtype policy mixed_float16 may run slowly because '
                    'this machine does not have a GPU. Only Nvidia GPUs with '
                    'compute capability of at least 7.0 run quickly with '
                    'mixed_float16.\n%s' % (_COMPAT_CHECK_WARNING_PREFIX,
                                            _COMPAT_CHECK_WARNING_SUFFIX))
  elif len(supported) == 1:
    # Exactly one GPU and it is compatible: log an OK message at INFO level.
    tf_logging.info('%s\n'
                    'Your GPU will likely run quickly with dtype policy '
                    'mixed_float16 as it has compute capability of at least '
                    '7.0. Your GPU: %s' % (_COMPAT_CHECK_OK_PREFIX,
                                           supported[0]))
  else:
    # Multiple GPUs, all compatible.
    tf_logging.info('%s\n'
                    'Your GPUs will likely run quickly with dtype policy '
                    'mixed_float16 as they all have compute capability of at '
                    'least 7.0' % _COMPAT_CHECK_OK_PREFIX)


# Module-level latch so the compatibility check is logged at most once per
# process, no matter how many times a mixed-precision policy is created.
_logged_compatibility_check = False


def log_device_compatibility_check(policy_name):
  """Logs a compatibility check if the devices support the policy.

  Currently only logs for the policy mixed_float16. A log is shown only the
  first time this function is called.

  Args:
    policy_name: The name of the dtype policy.
  """
  global _logged_compatibility_check
  if _logged_compatibility_check:
    return
  _logged_compatibility_check = True
  gpus = config.list_physical_devices('GPU')
  gpu_details_list = [config.get_device_details(gpu) for gpu in gpus]
  _log_device_compatibility_check(policy_name, gpu_details_list)