# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of ClusterResolvers for GCE instance groups."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.tf_export import tf_export


# The Google API client is an optional dependency: record whether it could be
# imported so that the resolver can raise a helpful error only when it is
# actually needed (i.e. when no `service` object is supplied by the caller).
_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False


@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the instances within the instance
  group and return a ClusterResolver object suitable for use for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group", port=8470,
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group", port=8470,
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver project. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group.
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server. This argument is required;
        it is appended to every instance IP address to build the task
        addresses.
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If left as 'default' and the Google API
        client is installed, this resolves to
        GoogleCredentials.get_application_default().
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If no `service` is given and the googleapiclient is not
        installed.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    # Only resolve application-default credentials when the client library is
    # importable; otherwise keep the sentinel so that callers who inject their
    # own `service` can still construct the resolver.
    if credentials == 'default':
      if _GOOGLE_API_CLIENT_INSTALLED:
        self._credentials = GoogleCredentials.get_application_default()

    if service is None:
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'GCE cluster resolver')
      self._service = discovery.build(
          'compute', 'v1',
          credentials=self._credentials)
    else:
      self._service = service

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. The single
      job is named after `task_type` and lists the sorted `ip:port` addresses
      of the running instances; the list is empty when the instance group has
      no running instances.
    """
    request_body = {'instanceState': 'RUNNING'}
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroups=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []

    # `listInstances_next` returns None once the last page has been consumed.
    while request is not None:
      response = request.execute()

      # The API omits the 'items' key when no instance matches the filter
      # (e.g. an empty instance group), so fall back to an empty list instead
      # of raising KeyError.
      items = response.get('items', [])
      for instance in items:
        # The 'instance' field is a resource URL; the name is its last segment.
        instance_name = instance['instance'].split('/')[-1]

        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)

        if instance_request is not None:
          instance_details = instance_request.execute()
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          instance_url = '%s:%s' % (ip_address, self._port)
          worker_list.append(instance_url)

      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    # Sort so that task indices map to addresses deterministically.
    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of a task in the cluster, prefixed by the RPC layer.

    Note that this calls `cluster_spec()`, which queries the GCE API.

    Args:
      task_type: Job name to look up. Falls back to the `task_type` of this
        resolver when None.
      task_id: Task index to look up. Falls back to the `task_id` of this
        resolver when None.
      rpc_layer: RPC protocol to prefix the address with. Falls back to the
        `rpc_layer` of this resolver when empty or None.

    Returns:
      `'<rpc_layer>://<address>'` when an RPC layer is available, the bare
      address when it is not, or an empty string when either the task type or
      the task id cannot be determined.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      if rpc_layer or self._rpc_layer:
        return '%s://%s' % (rpc_layer or self._rpc_layer, master)
      else:
        return master

    return ''

  @property
  def task_type(self):
    """The job name this resolver was created with."""
    return self._task_type

  @property
  def task_id(self):
    """The task index of this VM within the instance group."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    """Rejects any attempt to change the task type after construction."""
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    """Sets the task index of this VM within the instance group."""
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used to prefix addresses returned by `master()`."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    """Sets the RPC layer used to prefix addresses returned by `master()`."""
    self._rpc_layer = rpc_layer