# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""

import collections
import re

from tensorflow.core.protobuf.tpu import topology_pb2
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import config as framework_config
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

try:
  from cloud_tpu_client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  logging.debug(
      'Falling back to TensorFlow client; we recommend you install the Cloud '
      'TPU client directly with pip install cloud-tpu-client.')
  from tensorflow.python.tpu.client import client  # pylint: disable=g-import-not-at-top


def is_running_in_gce():
  return True


class _LocalCloudTpuClient(object):
  """Dummy local Cloud TPU client."""

  def api_available(self):
    return False


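# Matches TPU device names such as '/job:worker/replica:0/task:0/device:TPU:3',
# capturing the host (task) id and the core id.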
_TPU_DEVICE_REGEX = re.compile(
    r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
DeviceDetails = collections.namedtuple(
    'DeviceDetails', ['device_map', 'total_cores'])


class TPUClusterResolver(cluster_resolver.ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service.

  TPUClusterResolver supports the following distinct environments:
  Google Compute Engine
  Google Kubernetes Engine
  Google internal

  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
  Cloud TPUs.
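
  For example (a minimal sketch; assumes a reachable Cloud TPU and a TF2
  runtime):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)
  ```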
  """

  @staticmethod
  def connect(tpu=None,
              zone=None,
              project=None):
    """Initializes TPU and returns a TPUClusterResolver.

    This API will connect to the remote TPU cluster and initialize the TPU
    hardware. Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    ...     tpu='')

    It can be viewed as a convenient wrapper around the following code:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.

    Returns:
      An instance of TPUClusterResolver.

    Raises:
      NotFoundError: If no TPU devices are found in eager mode.
    """
    resolver = TPUClusterResolver(tpu, zone, project)
    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
    remote.connect_to_cluster(resolver)
    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
    tpu_strategy_util.initialize_tpu_system(resolver)
    return resolver

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given device names.

    Returns a namedtuple with two attributes:
      device_map: A map of host_ids to a list of core_ids.
      total_cores: The total number of cores within the TPU system.

    Args:
      devices: A list of devices returned by session.list_devices()
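
    For example (a minimal sketch; `FakeDevice` is a stand-in for the device
    objects returned by `session.list_devices()`):

    ```python
    FakeDevice = collections.namedtuple('FakeDevice', ['name'])
    devices = [
        FakeDevice('/job:worker/replica:0/task:0/device:TPU:0'),
        FakeDevice('/job:worker/replica:0/task:0/device:TPU:1'),
        FakeDevice('/job:worker/replica:0/task:1/device:TPU:0'),
    ]
    details = TPUClusterResolver._get_device_dict_and_cores(devices)
    # details.device_map == {'0': ['0', '1'], '1': ['0']}
    # details.total_cores == 3
    ```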
    """
    device_map = collections.defaultdict(list)
    num_cores = 0
    for device in devices:
      match = _TPU_DEVICE_REGEX.match(device.name)
      if match:
        host_id = match.group('host_id')
        core_id = match.group('core_id')
        device_map[host_id].append(core_id)
        num_cores += 1
    return DeviceDetails(device_map, num_cores)

  @staticmethod
  def _verify_and_return_same_core_count(device_dict):
    """Verifies that every host in device_dict has the same number of cores."""
    num_cores_per_host_set = (
        {len(core_ids) for core_ids in device_dict.values()})
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError('The number of TPU cores on each host is not the '
                         'same. This should never happen. Devices: {}'.format(
                             device_dict))
    return num_cores_per_host_set.pop()

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU
    APIs for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs. If set to "local", it will
        assume that the TPU is directly connected to the VM instead of over
        the network.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None, a TF server will be started. If coordinator_name
        is None, a TF server will not be started even if coordinator_address is
        None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of the
        discovery service. It should have two parameters {api} and {apiVersion}
        that when filled in produce an absolute URL to the discovery document
        for that service. The environment variable 'TPU_API_DISCOVERY_URL' will
        override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
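
    For example (a minimal sketch; the TPU name, zone, and project below are
    placeholders):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='my-tpu',              # placeholder TPU name
        zone='us-central1-b',      # placeholder GCE zone
        project='my-gcp-project')  # placeholder GCP project
    ```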
    """

    if tpu != 'local':
      # Default Cloud environment
      self._cloud_tpu_client = client.Client(
          tpu=tpu,
          zone=zone,
          project=project,
          credentials=credentials,
          service=service,
          discovery_url=discovery_url)
      self._tpu = self._cloud_tpu_client.name()
    else:
      # Directly connected TPU environment
      self._cloud_tpu_client = _LocalCloudTpuClient()
      self._tpu = 'local'

    # By default the task_type is 'worker' and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0
    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

    self._tpu_topology = None
    # Backing field for the `environment` property below; the empty string
    # denotes a Google Cloud (rather than Google-internal) environment.
    self._environment = ''

  def __enter__(self):
    self._cloud_tpu_client.enter()

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    self._cloud_tpu_client.exit(type, value, traceback)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    the first instance in the ClusterSpec returned by the cluster_spec
    function.

    If a non-TPU name is used when constructing a TPUClusterResolver, that will
    be returned instead (e.g. if the tpu argument's value when constructing
    this TPUClusterResolver was 'grpc://10.240.1.2:8470',
    'grpc://10.240.1.2:8470' will be returned).

    Args:
      task_type: (Optional, string) The type of the TensorFlow task of the
        master.
      task_id: (Optional, integer) The index of the TensorFlow task of the
        master.
      rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
        communicate with TPUs.

    Returns:
      string, the connection string to use when creating a session.

    Raises:
      ValueError: If none of the TPUs specified exists.
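
    For example (the address below is illustrative):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://10.240.1.2:8470')
    resolver.master()  # returns 'grpc://10.240.1.2:8470'
    ```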
    """

    if self._tpu != 'local':
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id are from the function parameters
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id are from the object
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # by default we take the first item in the cluster with the right name
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]
      return cluster_resolver.format_master_url(master, 'grpc')
    else:
      return ''

  def get_master(self):
    return self.master()

  def get_job_name(self):
    return self.task_type

  def get_tpu_system_metadata(self):
    """Returns the metadata of the TPU system.

    Users can call this method to get some facts about the TPU system, such
    as the total number of cores, the number of TPU workers, and the devices,
    e.g.:

    ```python

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tpu_system_metadata = resolver.get_tpu_system_metadata()
    num_hosts = tpu_system_metadata.num_hosts
    ```

    Returns:
      A `tf.tpu.experimental.TPUSystemMetadata` object.
    """
    cluster_spec = self.cluster_spec()
    cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
            self.master(),
            cluster_def=cluster_def,
            query_topology=False))

    return tpu_system_metadata

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs,
      or None.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
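
    For example (worker addresses and the coordinator entry are illustrative),
    the returned value may look like:

    ```python
    ClusterSpec({
        'worker': ['10.1.2.3:8470', '10.1.2.4:8470'],
        'coordinator': ['10.128.0.2:34567'],
    })
    ```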
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  0. [Local case.] When a TPU is connected directly to the VM.
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #      a. Create a ClusterSpec that includes the coordinator job
    #      b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
    #     tasks and
    #      a. Create a ClusterSpec with the coordinator
    #      b. Create a ClusterSpec without the coordinator
    ############################################################################

    if self._tpu != 'local':
      network_endpoints = self._cloud_tpu_client.network_endpoints()
      worker_list = [
          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
          for endpoint in network_endpoints
      ]
      cluster_spec = {self.task_type: worker_list}
      if self._coordinator_address:
        # {1, 2}.a
        cluster_spec[self._coordinator_name] = [self._coordinator_address]
      return server_lib.ClusterSpec(cluster_spec)
    else:
      return server_lib.ClusterSpec({})

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of TPU cores per worker.

    Connects to the master, lists all the devices present there, and counts
    them. It also verifies that the device count per host is the same across
    the cluster before returning the number of TPU cores per host.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Used to create a connection to a TPU master in order to
        retrieve the system metadata.

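    Returns:
      A dict mapping the accelerator type `'TPU'` to the number of TPU cores
      per worker, e.g. `{'TPU': 8}` (the core count shown is illustrative).
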
    Raises:
      RuntimeError: If we cannot talk to a TPU worker after retrying or if the
        number of TPU devices per host is different.
    """
    if self._tpu == 'local':
      return {
          'TPU':
              len([
                  d for d in framework_config.list_logical_devices()
                  if d.device_type == 'TPU'
              ])
      }

    retry_count = 1
    # TODO(b/120564445): Replace with standard library for retries.
    while True:
      try:
        device_details = TPUClusterResolver._get_device_dict_and_cores(
            cluster_resolver.get_accelerator_devices(
                self.master(), config_proto=config_proto))
        break
      except errors.DeadlineExceededError:
        error_message = ('Failed to connect to master. The TPU might not be '
                         'ready (e.g. still scheduling) or the master '
                         'address is incorrect: got (%s)' % self.master())
        if retry_count <= _TPU_CONN_RETRIES:
          logging.warning(error_message)
          logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES)
          retry_count += 1
        else:
          raise RuntimeError(error_message)

    if device_details.total_cores:
      return {
          'TPU':
              TPUClusterResolver._verify_and_return_same_core_count(
                  device_details.device_map)
      }
    return {'TPU': 0}

  def set_tpu_topology(self, serialized_tpu_topology):
    """Sets the tpu topology info stored in this resolver."""
    self._tpu_topology = topology_pb2.TopologyProto()
    self._tpu_topology.ParseFromString(serialized_tpu_topology)

  @property
  def tpu_hardware_feature(self):
    """Returns the TPU hardware feature from the stored topology, if any."""
    if self._tpu_topology is None:
      return self._tpu_topology
    return self._tpu_topology.tpu_hardware_feature

  @property
  def environment(self):
    """Returns the current environment in which TensorFlow is running."""
    return self._environment

  def _start_local_server(self):
    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
                                     protocol='grpc',
                                     config=None,
                                     start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))
    assert len(splits) == 3, self._server.target
    assert splits[0] == compat.as_bytes('grpc'), self._server.target
    self._coordinator_port = compat.as_text(splits[2])
    self._coordinator_address = '%s:%s' % (
        address, compat.as_text(self._coordinator_port))

  def __deepcopy__(self, memo):
    # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
    return self