# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu

_PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 min
_RETRY_TIMES = 12 * 24  # 1 day
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
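# Together, the pinging constants bound how long we wait for a TPU worker to
# become reachable: _RETRY_TIMES (288) attempts at
# _PINGING_MASTER_TIMEOUT_IN_MS (5 minutes) each, i.e. roughly one day.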

_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')

# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
# including num_cores and num_hosts.
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
    'num_cores',
    'num_hosts',
    'num_of_cores_per_host',
    'topology',
    'devices',
])
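# For reference, a hypothetical single-host configuration with eight cores
# would be described as:
#   _TPUSystemMetadata(num_cores=8, num_hosts=1, num_of_cores_per_host=8,
#                      topology=None, devices=(...))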


def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  if context.executing_eagerly():
    logical_devices = config.list_logical_devices()

    # We want the output type to match in both eager and session mode
    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
                                             d.device_type, 0, 0)
               for d in logical_devices]
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying TensorFlow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          with session_lib.Session(
              master_address,
              config=get_session_config_with_timeout(
                  _PINGING_MASTER_TIMEOUT_IN_MS,
                  cluster_def)) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the TensorFlow master. The TPU worker may '
               'not be ready (still scheduling) or the TensorFlow master '
               'address is incorrect: got (%s).' %
               (master_address))

        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
        if retry_count <= _RETRY_TIMES:
          logging.warning('%s', msg)
          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
          retry_count += 1
        else:
          raise ValueError(msg)

  for device in devices:
    spec = tf_device.DeviceSpec.from_string(device.name)
    if spec.device_type == 'TPU':
      device_dict[spec.task].append(spec.device_index)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores on each host is not the same. This '
          'should not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = _TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata

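# Example (sketch): querying metadata for a remote TPU worker. The address
# below is a placeholder, not a real endpoint.
#
#   metadata = _query_tpu_system_metadata('grpc://10.0.0.2:8470')
#   logging.info('%d cores across %d hosts', metadata.num_cores,
#                metadata.num_hosts)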

def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology."""
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Failed to initialize TPU system with master (%s). '
        'Please double check the TPU system is functional.' % (
            master_address))

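# Example (sketch): the serialized topology returned above can be parsed with
# tensorflow.python.tpu.topology. The address below is a placeholder.
#
#   from tensorflow.python.tpu import topology as topology_lib
#   serialized = _obtain_topology('grpc://10.0.0.2:8470', None)
#   topology = topology_lib.Topology(serialized=serialized)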

def get_session_config_with_timeout(timeout_in_ms, cluster_def):
  """Returns a session config given a timeout and a cluster configuration."""
  config_proto = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_ms, cluster_def=cluster_def)
  return config_proto

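# Example (sketch): a config with a one-minute operation timeout and no
# cluster spec; the values are illustrative only.
#
#   session_config = get_session_config_with_timeout(60 * 1000, None)
#   with session_lib.Session('grpc://10.0.0.2:8470',
#                            config=session_config) as sess:
#     devices = sess.list_devices()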

def master_job(master, cluster_def):
  """Returns the canonical job name to use to place TPU computations on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    A string containing the job name, or None if no job should be specified.

  Raises:
    ValueError: If the user needs to specify a tpu_job_name, because we are
      unable to infer the job name automatically, or if the user-specified job
      names are inappropriate.
  """
  # If the user specifies the tpu_job_name, use that.

  if master in _LOCAL_MASTERS:
    return None

  if (not cluster_def or not cluster_def.job):
    return _DEFAULT_JOB_NAME
  job_names = set(job.name for job in cluster_def.job)
  if _DEFAULT_JOB_NAME in job_names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(job_names) == 1:
    return cluster_def.job[0].name
  if len(job_names) == 2:
    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
      return job_names.pop()
    # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError('Could not infer TPU job name.')

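# Example (sketch): with no ClusterDef, a non-local master falls back to the
# default job name; the address is a placeholder.
#
#   master_job('grpc://10.0.0.2:8470', None)  # -> 'tpu_worker'
#   master_job('local', None)                 # -> None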