# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to connect to remote servers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from absl import logging

from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


_GRPC_PREFIX = "grpc://"
_LOCAL_MASTERS = ("", "local")


@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # When eager execution is enabled, connect to the remote host.
  tf.config.experimental_connect_to_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```

  Args:
    remote_host: a single remote server address, or a list of addresses, in
      host-port format.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None.
  """
  if not remote_host:
    raise ValueError("Must provide at least one remote_host")

  remote_hosts = nest.flatten(remote_host)
  cluster_spec = server_lib.ClusterSpec(
      {job_name: [_strip_prefix(host, _GRPC_PREFIX) for host in remote_hosts]})

  connect_to_cluster(cluster_spec)
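

# Usage sketch (illustrative only; the addresses below are hypothetical). When
# `remote_host` is given as a list, each entry becomes one task of the given
# job, addressable as /job:worker/task:<index>:
#
#   tf.config.experimental_connect_to_host(
#       ["host0.example.com:9876", "host1.example.com:9876"])
#   with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
#     y = math_ops.matmul(array_ops.ones([2, 2]), array_ops.ones([2, 2]))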


@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
  """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the
  following device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified,
      will use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
      `tf.train.experimental.ClusterDeviceFilters` that specifies device
      filters for the remote tasks in the cluster.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "`tf.config.experimental_connect_to_cluster` can only be called in "
        "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
        "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
        "`ClusterResolver`.")

  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
    else:
      raise ValueError("`cluster_device_filters` must be an instance of "
                       "`tf.train.experimental.ClusterDeviceFilters`.")

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    # TODO(fishx): Update this to make sure remote worker has valid ip address
    # to connect with local.
    job_def.tasks[0] = "localhost:{}".format(local_port)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=context.context().config,
      cluster_device_filters=cluster_device_filters)

  if context.get_server_def() is None:
    context.set_server_def(server_def)
  else:
    context.update_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError("`connect_to_cluster` is called inside existing device "
                       "scope %s, which is different from the master device "
                       "scope %s to enter. This is not allowed." %
                       (current_device, master_device))
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      ops.device(master_device).__enter__()


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s
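

# Illustrative end-to-end sketch (hypothetical addresses; assumes eager mode).
# Connecting with a `tf.train.ClusterSpec` makes the listed tasks' devices
# available for explicit placement:
#
#   cluster_spec = tf.train.ClusterSpec(
#       {"worker": ["worker0.example.com:2222", "worker1.example.com:2222"]})
#   tf.config.experimental_connect_to_cluster(cluster_spec)
#   with ops.device("/job:worker/replica:0/task:0/device:CPU:0"):
#     y = math_ops.matmul(array_ops.ones([2, 2]), array_ops.ones([2, 2]))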