# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu

_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000  # 1 min
_RETRY_TIMES = 120
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins

_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')

_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')

# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
# including num_cores and num_hosts.
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
    'num_cores',
    'num_hosts',
    'num_of_cores_per_host',
    'topology',
    'devices',
])


def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata in the system."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  if context.executing_eagerly():
    device_names = context.list_devices()
    devices = []

    # We want the output type to match in both eager and session mode.
    for name in device_names:
      device_match = _DEVICE_TYPE_REGEX.match(name)
      device_type = 'CPU'
      if device_match:
        device_type = device_match.group(1)
      devices.append(session_lib._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying TensorFlow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          with session_lib.Session(
              master_address,
              config=get_session_config_with_timeout(
                  _PINGING_MASTER_TIMEOUT_IN_MS,
                  cluster_def)) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the TensorFlow master. The TPU worker may '
               'not be ready (still scheduling) or the TensorFlow master '
               'address is incorrect: got (%s).' %
               (master_address))

        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
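        # With _PINGING_MASTER_TIMEOUT_IN_MS at one minute per attempt and
        # _RETRY_TIMES at 120, this amounts to pinging the master for roughly
        # two hours before giving up.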
        if retry_count <= _RETRY_TIMES:
          logging.warning('%s', msg)
          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
          retry_count += 1
        else:
          raise ValueError(msg)

  for device in devices:
    match = _TPU_DEVICE_REG.match(device.name)
    if match:
      host_id = match.group(1)
      core_id = match.group(2)
      device_dict[host_id].append(core_id)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores on each host is not the same. This should '
          'not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = _TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata


def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology."""
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Failed to initialize the TPU system with master (%s). '
        'Please double check that the TPU system is functional.' % (
            master_address))


def get_session_config_with_timeout(timeout_in_ms, cluster_def):
  """Returns a session config with the given timeout and cluster config."""
  config = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_ms, cluster_def=cluster_def)
  return config
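

# A minimal illustration (a sketch, not part of this module's API) of what the
# helper above produces for the module-level ping timeout:
#
#   config = get_session_config_with_timeout(
#       _PINGING_MASTER_TIMEOUT_IN_MS, cluster_def=None)
#   config.operation_timeout_in_ms  # == 60 * 1000
#
# Any blocking op run under a session built from this config (such as the
# list_devices() call in _query_tpu_system_metadata) raises
# errors.DeadlineExceededError once the timeout elapses, which is what the
# retry loop above relies on.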


def master_job(master, cluster_def):
  """Returns the canonical job name to use to place TPU computations on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    A string containing the job name, or None if no job should be specified.

  Raises:
    ValueError: If the user needs to specify a tpu_job_name, because we are
      unable to infer the job name automatically, or if the user-specified job
      names are inappropriate.
  """
  # If the user specifies the tpu_job_name, use that.

  if master in _LOCAL_MASTERS:
    return None

  if (not cluster_def or not cluster_def.job):
    return _DEFAULT_JOB_NAME
  job_names = set([job.name for job in cluster_def.job])
  if _DEFAULT_JOB_NAME in job_names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(job_names) == 1:
    return cluster_def.job[0].name
  if len(job_names) == 2:
    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
      return job_names.pop()
  # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError(
      'Could not infer TPU job name. Please specify a tpu_job_name as part '
      'of your TPUConfig.')
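

# A usage sketch for master_job (hypothetical addresses and job names, for
# illustration only): given a two-job cluster that includes a coordinator,
# the TPU worker job is inferred; a local master yields no job at all.
#
#   from tensorflow.core.protobuf import cluster_pb2
#
#   cluster_def = cluster_pb2.ClusterDef()
#   cluster_def.job.add(name='coordinator')
#   cluster_def.job.add(name='worker')
#
#   master_job('grpc://10.2.3.4:8470', cluster_def)  # -> 'worker'
#   master_job('local', cluster_def)                 # -> None (local master)
#   master_job('grpc://10.2.3.4:8470', None)         # -> 'tpu_worker'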