• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of ClusterResolvers for GCE instance groups."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
22from tensorflow.python.training.server_lib import ClusterSpec
23from tensorflow.python.util.tf_export import tf_export
24
25
26_GOOGLE_API_CLIENT_INSTALLED = True
27try:
28  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
29  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
30except ImportError:
31  _GOOGLE_API_CLIENT_INSTALLED = False
32
33
@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this resolver retrieves the internal IP address of every RUNNING instance in
  the instance group and builds a ClusterSpec suitable for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group", port=8470,
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group", port=8470,
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This stores the parameters used by `cluster_spec()` to query the GCE API
    for the IP addresses of each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group (e.g. "us-west1-a"; a zone, not a
        region).
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server on every instance. This
        argument is required; it has no default value.
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to. All instances are listed under this job name.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If left as 'default', this is replaced
        with GoogleCredentials.get_application_default() when the service
        object is built by this class.
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If `service` is not provided and the googleapiclient is not
        installed.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    if service is None:
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'GCE cluster resolver')
      # Application-default credentials are only needed to build our own
      # service; looking them up eagerly can fail when none are configured.
      if credentials == 'default':
        self._credentials = GoogleCredentials.get_application_default()
      self._service = discovery.build(
          'compute', 'v1',
          credentials=self._credentials)
    else:
      # A caller-supplied service carries its own credentials, so the
      # `credentials` argument is deliberately not resolved here.
      self._service = service

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called. Only instances in the RUNNING state are
    included. Every instance is listed under `task_type` as
    '<networkIP of its first network interface>:<port>'. The address strings
    are sorted lexicographically, so a task's index is its position in that
    order.

    Returns:
      A ClusterSpec containing host information retrieved from GCE.
    """
    request_body = {'instanceState': 'RUNNING'}
    # The Compute v1 method's path parameter is `instanceGroup` (singular);
    # googleapiclient rejects any other keyword with a TypeError.
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroup=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []

    # Follow pagination until listInstances_next() signals the last page by
    # returning None.
    while request is not None:
      response = request.execute()

      # The API may omit 'items' entirely on an empty page (e.g. when no
      # instance is RUNNING), so treat a missing key as no instances.
      for instance in response.get('items', []):
        # 'instance' is a resource URL; its last path segment is the name.
        instance_name = instance['instance'].split('/')[-1]

        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)

        if instance_request is not None:
          instance_details = instance_request.execute()
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          instance_url = '%s:%s' % (ip_address, self._port)
          worker_list.append(instance_url)

      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of a task, optionally prefixed with an RPC scheme.

    This calls `cluster_spec()`, which queries the GCE API.

    Args:
      task_type: Job name of the task. Defaults to this resolver's task_type.
      task_id: Index of the task. Defaults to this resolver's task_id.
      rpc_layer: RPC scheme to prepend. Defaults to this resolver's rpc_layer;
        if both are empty, the address is returned without a scheme.

    Returns:
      '<rpc_layer>://<address>' or '<address>' for the resolved task, or '' if
      either the task type or the task id resolves to None.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      if rpc_layer or self._rpc_layer:
        return '%s://%s' % (rpc_layer or self._rpc_layer, master)
      else:
        return master

    return ''

  @property
  def task_type(self):
    """The TensorFlow job name this resolver was created with."""
    return self._task_type

  @property
  def task_id(self):
    """The task index of this VM within the instance group."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    # The job name is fixed at construction time; resetting it is refused.
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used as the scheme prefix by `master()`."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer
212