• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains function to log if devices are compatible with mixed precision."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22
23from tensorflow.python.framework import config
24from tensorflow.python.platform import tf_logging
25
26
# Common lead-in shared by every message this module logs.
_COMPAT_CHECK_PREFIX = 'Mixed precision compatibility check (mixed_float16): '
# Header for the informational message emitted when all GPUs qualify.
_COMPAT_CHECK_OK_PREFIX = '%sOK' % _COMPAT_CHECK_PREFIX
# Header for the warnings emitted when a GPU is missing or may be slow.
_COMPAT_CHECK_WARNING_PREFIX = '%sWARNING' % _COMPAT_CHECK_PREFIX
# Trailer appended to every warning message.
_COMPAT_CHECK_WARNING_SUFFIX = ''.join((
    'If you will use compatible GPU(s) not attached to this host, e.g. by ',
    'running a multi-worker model, you can ignore this warning. This message ',
    'will only be logged once',
))
34
35
def _dedup_strings(device_strs):
  """Collapses runs of identical adjacent strings into one annotated entry.

  Only *adjacent* duplicates are merged; a string that appears in two separate
  runs yields two entries. For example:
      ['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3']
  becomes:
      ['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)']

  Args:
    device_strs: A list of strings, each describing one device.

  Returns:
    A new list in which every run of identical consecutive strings is replaced
    by a single string. Runs longer than one element get a ' (xN)' suffix,
    where N is the run length.
  """

  def _describe_run(text, run):
    # `run` is the groupby sub-iterator; it is consumed here, before groupby
    # advances to the next key.
    count = sum(1 for _ in run)
    if count > 1:
      return '%s (x%d)' % (text, count)
    return text

  return [_describe_run(text, run)
          for text, run in itertools.groupby(device_strs)]
59
60
def _log_device_compatibility_check(policy_name, gpu_details_list):
  """Logs a compatibility check if the devices support the policy.

  Currently only logs for the policy mixed_float16. A GPU whose compute
  capability is at least 7.0 is treated as running mixed_float16 quickly; a GPU
  with a lower or missing compute capability is treated as possibly slow.

  Exactly one message is logged:
    * Some GPU is possibly slow: a warning listing every GPU.
    * There are no GPUs: a warning that the machine has no GPU.
    * Exactly one GPU, and it is fast: an info message naming it.
    * Several GPUs, all fast: an info message.

  Args:
    policy_name: The name of the dtype policy. Any value other than
      'mixed_float16' makes this function a no-op.
    gpu_details_list: A list of dicts, one dict per GPU. Each dict
      is the device details for a GPU, as returned by
      `tf.config.experimental.get_device_details()`. Only the 'device_name'
      and 'compute_capability' keys are read; both may be absent.
  """
  if policy_name != 'mixed_float16':
    # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This requires
    # checking if a TPU is available.
    return
  supported_device_strs = []
  unsupported_device_strs = []
  for details in gpu_details_list:
    name = details.get('device_name', 'Unknown GPU')
    cc = details.get('compute_capability')
    if cc:
      device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1])
      # Lexicographic tuple comparison against (major=7, minor=0); assumes `cc`
      # is a (major, minor) tuple as returned by get_device_details.
      if cc >= (7, 0):
        supported_device_strs.append(device_str)
      else:
        unsupported_device_strs.append(device_str)
    else:
      unsupported_device_strs.append(
          name + ', no compute capability (probably not an Nvidia GPU)')

  if unsupported_device_strs:
    warning_str = _COMPAT_CHECK_WARNING_PREFIX + '\n'
    if supported_device_strs:
      warning_str += ('Some of your GPUs may run slowly with dtype policy '
                      'mixed_float16 because they do not all have compute '
                      'capability of at least 7.0. Your GPUs:\n')
    elif len(unsupported_device_strs) == 1:
      warning_str += ('Your GPU may run slowly with dtype policy mixed_float16 '
                      'because it does not have compute capability of at least '
                      '7.0. Your GPU:\n')
    else:
      warning_str += ('Your GPUs may run slowly with dtype policy '
                      'mixed_float16 because they do not have compute '
                      'capability of at least 7.0. Your GPUs:\n')
    # Supported GPUs are listed first, then unsupported ones; identical
    # consecutive entries are collapsed into one line with a count.
    for device_str in _dedup_strings(supported_device_strs +
                                     unsupported_device_strs):
      warning_str += '  ' + device_str + '\n'
    warning_str += ('See https://developer.nvidia.com/cuda-gpus for a list of '
                    'GPUs and their compute capabilities.\n')
    warning_str += _COMPAT_CHECK_WARNING_SUFFIX
    # `warning` rather than the v1-only `warn` alias (which mirrors the
    # deprecated `logging.Logger.warn`).
    tf_logging.warning(warning_str)
  elif not supported_device_strs:
    # No GPUs were reported at all.
    tf_logging.warning('%s\n'
                       'The dtype policy mixed_float16 may run slowly because '
                       'this machine does not have a GPU. Only Nvidia GPUs with '
                       'compute capability of at least 7.0 run quickly with '
                       'mixed_float16.\n%s' % (_COMPAT_CHECK_WARNING_PREFIX,
                                               _COMPAT_CHECK_WARNING_SUFFIX))
  elif len(supported_device_strs) == 1:
    tf_logging.info('%s\n'
                    'Your GPU will likely run quickly with dtype policy '
                    'mixed_float16 as it has compute capability of at least '
                    '7.0. Your GPU: %s' % (_COMPAT_CHECK_OK_PREFIX,
                                           supported_device_strs[0]))
  else:
    tf_logging.info('%s\n'
                    'Your GPUs will likely run quickly with dtype policy '
                    'mixed_float16 as they all have compute capability of at '
                    'least 7.0' % _COMPAT_CHECK_OK_PREFIX)
130
131
# Set to True once a mixed_float16 compatibility check has been attempted, so
# the check (and its log message) runs at most once.
_logged_compatibility_check = False


def log_device_compatibility_check(policy_name):
  """Logs a compatibility check if the devices support the policy.

  Currently only logs for the policy mixed_float16. A log is shown only the
  first time this function is called with mixed_float16. Calls with any other
  policy return immediately: they do not query devices and do not prevent a
  later mixed_float16 call from logging.

  Args:
    policy_name: The name of the dtype policy.
  """
  global _logged_compatibility_check
  # Keep in sync with the policies `_log_device_compatibility_check` logs for.
  # Other policies produce no log, so they must not consume the one-shot flag,
  # and there is no reason to enumerate devices for them.
  if policy_name != 'mixed_float16':
    return
  if _logged_compatibility_check:
    return
  # The flag is set before the device queries, so an exception raised below
  # still counts as the single attempt.
  _logged_compatibility_check = True
  gpus = config.list_physical_devices('GPU')
  gpu_details_list = [config.get_device_details(g) for g in gpus]
  _log_device_compatibility_check(policy_name, gpu_details_list)
151