# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu

_PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 min
_RETRY_TIMES = 12 * 24  # 1 day
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 min

_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')

# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
# including num_cores and num_hosts.
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
    'num_cores',
    'num_hosts',
    'num_of_cores_per_host',
    'topology',
    'devices',
])


def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  if context.executing_eagerly():
    logical_devices = config.list_logical_devices()

    # We want the output type to match in both eager and session modes.
    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
                                             d.device_type, 0, 0)
               for d in logical_devices]
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying TensorFlow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          with session_lib.Session(
              master_address,
              config=get_session_config_with_timeout(
                  _PINGING_MASTER_TIMEOUT_IN_MS,
                  cluster_def)) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the TensorFlow master. The TPU worker may '
               'not be ready (still scheduling) or the TensorFlow master '
               'address is incorrect: got (%s).' %
               (master_address))

        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
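        # Retry budget: with the 5-minute ping timeout above, _RETRY_TIMES
        # (12 * 24 = 288) attempts give a still-scheduling TPU worker roughly
        # one day to come up before we give up.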
        if retry_count <= _RETRY_TIMES:
          logging.warning('%s', msg)
          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
          retry_count += 1
        else:
          raise ValueError(msg)

  for device in devices:
    spec = tf_device.DeviceSpec.from_string(device.name)
    if spec.device_type == 'TPU':
      device_dict[spec.task].append(spec.device_index)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores on each host is not the same. This should '
          'not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = _TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata


def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology."""
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Failed to initialize the TPU system with master (%s). '
        'Please double check that the TPU system is functional.' % (
            master_address))


def get_session_config_with_timeout(timeout_in_ms, cluster_def):
  """Returns a session config proto with the given timeout and cluster_def."""
  config_proto = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_ms, cluster_def=cluster_def)
  return config_proto


def master_job(master, cluster_def):
  """Returns the canonical job name to use to place TPU computations on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    A string containing the job name, or None if no job should be specified.

  Raises:
    ValueError: If the user needs to specify a tpu_job_name, because we are
      unable to infer the job name automatically, or if the user-specified job
      names are inappropriate.
  """
  # If the user specifies the tpu_job_name, use that.

  if master in _LOCAL_MASTERS:
    return None

  if not cluster_def or not cluster_def.job:
    return _DEFAULT_JOB_NAME
  job_names = set(job.name for job in cluster_def.job)
  if _DEFAULT_JOB_NAME in job_names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(job_names) == 1:
    return cluster_def.job[0].name
  if len(job_names) == 2:
    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
      return job_names.pop()
  # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError('Could not infer TPU job name.')
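

# The sketch below is an illustrative usage example, not part of this module's
# API: it shows how `master_job` resolves a TPU job name for a few hypothetical
# cluster configurations (the job names and addresses are made up).
if __name__ == '__main__':
  from tensorflow.core.protobuf import cluster_pb2

  example_cluster = cluster_pb2.ClusterDef()
  coordinator = example_cluster.job.add()
  coordinator.name = 'coordinator'
  coordinator.tasks[0] = 'localhost:2222'
  worker = example_cluster.job.add()
  worker.name = 'worker'
  worker.tasks[0] = '10.0.0.2:8470'

  # Two jobs, one of them the default coordinator: the other job is inferred.
  print(master_job('grpc://10.0.0.2:8470', example_cluster))  # -> 'worker'
  # Local masters never need an explicit job name.
  print(master_job('local', example_cluster))                 # -> None
  # No ClusterDef at all: fall back to the default job name.
  print(master_job('grpc://10.0.0.2:8470', None))             # -> 'tpu_worker'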