1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A Python interface for creating TensorFlow servers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.core.protobuf import cluster_pb2 22from tensorflow.core.protobuf import device_filters_pb2 23from tensorflow.core.protobuf import tensorflow_server_pb2 24from tensorflow.python.client import pywrap_tf_session as c_api 25from tensorflow.python.framework import errors 26from tensorflow.python.util import compat 27from tensorflow.python.util import deprecation 28from tensorflow.python.util.tf_export import tf_export 29 30 31def _make_server_def(server_or_cluster_def, job_name, task_index, protocol, 32 config): 33 """Creates a `tf.train.ServerDef` protocol buffer. 34 35 Args: 36 server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef` 37 protocol buffer, or a `tf.train.ClusterSpec` object, describing the server 38 to be defined and/or the cluster of which it is a member. 39 job_name: (Optional.) Specifies the name of the job of which the server is a 40 member. Defaults to the value in `server_or_cluster_def`, if specified. 41 task_index: (Optional.) Specifies the task index of the server in its job. 42 Defaults to the value in `server_or_cluster_def`, if specified. Otherwise 43 defaults to 0 if the server's job has only one task. 44 protocol: (Optional.) Specifies the protocol to be used by the server. 45 Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value in 46 `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`. 47 config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default 48 configuration options for all sessions that run on this server. 49 50 Returns: 51 A `tf.train.ServerDef`. 52 53 Raises: 54 TypeError: If the arguments do not have the appropriate type. 55 ValueError: If an argument is not specified and cannot be inferred. 56 """ 57 server_def = tensorflow_server_pb2.ServerDef() 58 if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef): 59 server_def.MergeFrom(server_or_cluster_def) 60 if job_name is not None: 61 server_def.job_name = job_name 62 if task_index is not None: 63 server_def.task_index = task_index 64 if protocol is not None: 65 server_def.protocol = protocol 66 if config is not None: 67 server_def.default_session_config.MergeFrom(config) 68 else: 69 try: 70 cluster_spec = ClusterSpec(server_or_cluster_def) 71 except TypeError: 72 raise TypeError("Could not convert `server_or_cluster_def` to a " 73 "`tf.train.ServerDef` or `tf.train.ClusterSpec`.") 74 if job_name is None: 75 if len(cluster_spec.jobs) == 1: 76 job_name = cluster_spec.jobs[0] 77 else: 78 raise ValueError("Must specify an explicit `job_name`.") 79 if task_index is None: 80 task_indices = cluster_spec.task_indices(job_name) 81 if len(task_indices) == 1: 82 task_index = task_indices[0] 83 else: 84 raise ValueError("Must specify an explicit `task_index`.") 85 if protocol is None: 86 protocol = "grpc" 87 88 server_def = tensorflow_server_pb2.ServerDef( 89 cluster=cluster_spec.as_cluster_def(), 90 job_name=job_name, 91 task_index=task_index, 92 protocol=protocol) 93 if config is not None: 94 server_def.default_session_config.MergeFrom(config) 95 return server_def 96 97 98@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"]) 99@deprecation.deprecated_endpoints("train.Server") 100class Server(object): 101 """An in-process TensorFlow server, for use in distributed training. 102 103 A `tf.distribute.Server` instance encapsulates a set of devices and a 104 `tf.compat.v1.Session` target that 105 can participate in distributed training. A server belongs to a 106 cluster (specified by a `tf.train.ClusterSpec`), and 107 corresponds to a particular task in a named job. The server can 108 communicate with any other server in the same cluster. 109 """ 110 111 def __init__(self, 112 server_or_cluster_def, 113 job_name=None, 114 task_index=None, 115 protocol=None, 116 config=None, 117 start=True): 118 """Creates a new server with the given definition. 119 120 The `job_name`, `task_index`, and `protocol` arguments are optional, and 121 override any information provided in `server_or_cluster_def`. 122 123 Args: 124 server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef` 125 protocol buffer, or a `tf.train.ClusterSpec` object, describing the 126 server to be created and/or the cluster of which it is a member. 127 job_name: (Optional.) Specifies the name of the job of which the server is 128 a member. Defaults to the value in `server_or_cluster_def`, if 129 specified. 130 task_index: (Optional.) Specifies the task index of the server in its job. 131 Defaults to the value in `server_or_cluster_def`, if specified. 132 Otherwise defaults to 0 if the server's job has only one task. 133 protocol: (Optional.) Specifies the protocol to be used by the server. 134 Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value 135 in `server_or_cluster_def`, if specified. Otherwise defaults to 136 `"grpc"`. 137 config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default 138 configuration options for all sessions that run on this server. 139 start: (Optional.) Boolean, indicating whether to start the server after 140 creating it. Defaults to `True`. 141 142 Raises: 143 tf.errors.OpError: Or one of its subclasses if an error occurs while 144 creating the TensorFlow server. 145 """ 146 self._server_def = _make_server_def(server_or_cluster_def, job_name, 147 task_index, protocol, config) 148 self._server = c_api.TF_NewServer(self._server_def.SerializeToString()) 149 if start: 150 self.start() 151 152 def __del__(self): 153 try: 154 c_api.TF_ServerStop(self._server) 155 # Clean shutdown of servers is not yet implemented, so 156 # we leak instead of calling c_api.TF_DeleteServer here. 157 # See: 158 # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73 159 except errors.UnimplementedError: 160 pass 161 except AttributeError: 162 # At shutdown, `c_api` may have been garbage collected. 163 pass 164 self._server = None 165 166 def start(self): 167 """Starts this server. 168 169 Raises: 170 tf.errors.OpError: Or one of its subclasses if an error occurs while 171 starting the TensorFlow server. 172 """ 173 c_api.TF_ServerStart(self._server) 174 175 def join(self): 176 """Blocks until the server has shut down. 177 178 This method currently blocks forever. 179 180 Raises: 181 tf.errors.OpError: Or one of its subclasses if an error occurs while 182 joining the TensorFlow server. 183 """ 184 c_api.TF_ServerJoin(self._server) 185 186 @property 187 def server_def(self): 188 """Returns the `tf.train.ServerDef` for this server. 189 190 Returns: 191 A `tf.train.ServerDef` protocol buffer that describes the configuration 192 of this server. 193 """ 194 return self._server_def 195 196 @property 197 def target(self): 198 """Returns the target for a `tf.compat.v1.Session` to connect to this server. 199 200 To create a 201 `tf.compat.v1.Session` that 202 connects to this server, use the following snippet: 203 204 ```python 205 server = tf.distribute.Server(...) 206 with tf.compat.v1.Session(server.target): 207 # ... 208 ``` 209 210 Returns: 211 A string containing a session target for this server. 212 """ 213 return c_api.TF_ServerTarget(self._server) 214 215 @staticmethod 216 def create_local_server(config=None, start=True): 217 """Creates a new single-process cluster running on the local host. 218 219 This method is a convenience wrapper for creating a 220 `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a 221 single-process cluster containing a single task in a job called 222 `"local"`. 223 224 Args: 225 config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default 226 configuration options for all sessions that run on this server. 227 start: (Optional.) Boolean, indicating whether to start the server after 228 creating it. Defaults to `True`. 229 230 Returns: 231 A local `tf.distribute.Server`. 232 """ 233 # Specifying port 0 means that the OS will choose a free port for the 234 # server. 235 return Server({"localhost": ["localhost:0"]}, 236 protocol="grpc", 237 config=config, 238 start=start) 239 240 241@tf_export("train.ClusterSpec") 242class ClusterSpec(object): 243 """Represents a cluster as a set of "tasks", organized into "jobs". 244 245 A `tf.train.ClusterSpec` represents the set of processes that 246 participate in a distributed TensorFlow computation. Every 247 `tf.distribute.Server` is constructed in a particular cluster. 248 249 To create a cluster with two jobs and five tasks, you specify the 250 mapping from job names to lists of network addresses (typically 251 hostname-port pairs). 252 253 ```python 254 cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222", 255 "worker1.example.com:2222", 256 "worker2.example.com:2222"], 257 "ps": ["ps0.example.com:2222", 258 "ps1.example.com:2222"]}) 259 ``` 260 261 Each job may also be specified as a sparse mapping from task indices 262 to network addresses. This enables a server to be configured without 263 needing to know the identity of (for example) all other worker 264 tasks: 265 266 ```python 267 cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"}, 268 "ps": ["ps0.example.com:2222", 269 "ps1.example.com:2222"]}) 270 ``` 271 """ 272 273 def __init__(self, cluster): 274 """Creates a `ClusterSpec`. 275 276 Args: 277 cluster: A dictionary mapping one or more job names to (i) a list of 278 network addresses, or (ii) a dictionary mapping integer task indices to 279 network addresses; or a `tf.train.ClusterDef` protocol buffer. 280 281 Raises: 282 TypeError: If `cluster` is not a dictionary mapping strings to lists 283 of strings, and not a `tf.train.ClusterDef` protobuf. 284 """ 285 if isinstance(cluster, dict): 286 self._cluster_spec = {} 287 for job_name, tasks in cluster.items(): 288 if isinstance(tasks, (list, tuple)): 289 job_tasks = {i: task for i, task in enumerate(tasks)} 290 elif isinstance(tasks, dict): 291 job_tasks = {i: task for i, task in tasks.items()} 292 else: 293 raise TypeError("The tasks for job %r must be a list or a dictionary " 294 "from integers to strings." % job_name) 295 self._cluster_spec[job_name] = job_tasks 296 self._make_cluster_def() 297 elif isinstance(cluster, cluster_pb2.ClusterDef): 298 self._cluster_def = cluster 299 self._cluster_spec = {} 300 for job_def in self._cluster_def.job: 301 self._cluster_spec[job_def.name] = { 302 i: t for i, t in job_def.tasks.items() 303 } 304 elif isinstance(cluster, ClusterSpec): 305 self._cluster_def = cluster_pb2.ClusterDef() 306 self._cluster_def.MergeFrom(cluster.as_cluster_def()) 307 self._cluster_spec = {} 308 for job_def in self._cluster_def.job: 309 self._cluster_spec[job_def.name] = { 310 i: t for i, t in job_def.tasks.items() 311 } 312 else: 313 raise TypeError("`cluster` must be a dictionary mapping one or more " 314 "job names to lists of network addresses, or a " 315 "`ClusterDef` protocol buffer") 316 317 def __nonzero__(self): 318 return bool(self._cluster_spec) 319 320 # Python 3.x 321 __bool__ = __nonzero__ 322 323 def __eq__(self, other): 324 return self._cluster_spec == other 325 326 def __ne__(self, other): 327 return self._cluster_spec != other 328 329 def __repr__(self): 330 key_values = self.as_dict() 331 string_items = [ 332 repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values) 333 ] 334 return "ClusterSpec({" + ", ".join(string_items) + "})" 335 336 def as_dict(self): 337 """Returns a dictionary from job names to their tasks. 338 339 For each job, if the task index space is dense, the corresponding 340 value will be a list of network addresses; otherwise it will be a 341 dictionary mapping (sparse) task indices to the corresponding 342 addresses. 343 344 Returns: 345 A dictionary mapping job names to lists or dictionaries 346 describing the tasks in those jobs. 347 """ 348 ret = {} 349 for job in self.jobs: 350 task_indices = self.task_indices(job) 351 if len(task_indices) == 0: 352 ret[job] = {} 353 continue 354 if max(task_indices) + 1 == len(task_indices): 355 # Return a list because the task indices are dense. This 356 # matches the behavior of `as_dict()` before support for 357 # sparse jobs was added. 358 ret[job] = self.job_tasks(job) 359 else: 360 ret[job] = {i: self.task_address(job, i) for i in task_indices} 361 return ret 362 363 def as_cluster_def(self): 364 """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster.""" 365 return self._cluster_def 366 367 @property 368 def jobs(self): 369 """Returns a list of job names in this cluster. 370 371 Returns: 372 A list of strings, corresponding to the names of jobs in this cluster. 373 """ 374 return list(self._cluster_spec.keys()) 375 376 def num_tasks(self, job_name): 377 """Returns the number of tasks defined in the given job. 378 379 Args: 380 job_name: The string name of a job in this cluster. 381 382 Returns: 383 The number of tasks defined in the given job. 384 385 Raises: 386 ValueError: If `job_name` does not name a job in this cluster. 387 """ 388 try: 389 job = self._cluster_spec[job_name] 390 except KeyError: 391 raise ValueError("No such job in cluster: %r" % job_name) 392 return len(job) 393 394 def task_indices(self, job_name): 395 """Returns a list of valid task indices in the given job. 396 397 Args: 398 job_name: The string name of a job in this cluster. 399 400 Returns: 401 A list of valid task indices in the given job. 402 403 Raises: 404 ValueError: If `job_name` does not name a job in this cluster, 405 or no task with index `task_index` is defined in that job. 406 """ 407 try: 408 job = self._cluster_spec[job_name] 409 except KeyError: 410 raise ValueError("No such job in cluster: %r" % job_name) 411 return list(sorted(job.keys())) 412 413 def task_address(self, job_name, task_index): 414 """Returns the address of the given task in the given job. 415 416 Args: 417 job_name: The string name of a job in this cluster. 418 task_index: A non-negative integer. 419 420 Returns: 421 The address of the given task in the given job. 422 423 Raises: 424 ValueError: If `job_name` does not name a job in this cluster, 425 or no task with index `task_index` is defined in that job. 426 """ 427 try: 428 job = self._cluster_spec[job_name] 429 except KeyError: 430 raise ValueError("No such job in cluster: %r" % job_name) 431 try: 432 return job[task_index] 433 except KeyError: 434 raise ValueError("No task with index %r in job %r" % 435 (task_index, job_name)) 436 437 def job_tasks(self, job_name): 438 """Returns a mapping from task ID to address in the given job. 439 440 NOTE: For backwards compatibility, this method returns a list. If 441 the given job was defined with a sparse set of task indices, the 442 length of this list may not reflect the number of tasks defined in 443 this job. Use the `tf.train.ClusterSpec.num_tasks` method 444 to find the number of tasks defined in a particular job. 445 446 Args: 447 job_name: The string name of a job in this cluster. 448 449 Returns: 450 A list of task addresses, where the index in the list 451 corresponds to the task index of each task. The list may contain 452 `None` if the job was defined with a sparse set of task indices. 453 454 Raises: 455 ValueError: If `job_name` does not name a job in this cluster. 456 """ 457 try: 458 job = self._cluster_spec[job_name] 459 except KeyError: 460 raise ValueError("No such job in cluster: %r" % job_name) 461 ret = [None for _ in range(max(job.keys()) + 1)] 462 for i, task in job.items(): 463 ret[i] = task 464 return ret 465 466 def _make_cluster_def(self): 467 """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`. 468 469 Raises: 470 TypeError: If `cluster_spec` is not a dictionary mapping strings to lists 471 of strings. 472 """ 473 self._cluster_def = cluster_pb2.ClusterDef() 474 475 # NOTE(mrry): Sort by job_name to produce deterministic protobufs. 476 for job_name, tasks in sorted(self._cluster_spec.items()): 477 try: 478 job_name = compat.as_bytes(job_name) 479 except TypeError: 480 raise TypeError("Job name %r must be bytes or unicode" % job_name) 481 482 job_def = self._cluster_def.job.add() 483 job_def.name = job_name 484 485 for i, task_address in sorted(tasks.items()): 486 try: 487 task_address = compat.as_bytes(task_address) 488 except TypeError: 489 raise TypeError("Task address %r must be bytes or unicode" % 490 task_address) 491 job_def.tasks[i] = task_address 492 493 494@tf_export("config.experimental.ClusterDeviceFilters") 495class ClusterDeviceFilters(object): 496 """Represent a collection of device filters for the remote workers in cluster. 497 498 NOTE: this is an experimental API and subject to changes. 499 500 Set device filters for selective jobs and tasks. For each remote worker, the 501 device filters are a list of strings. When any filters are present, the remote 502 worker will ignore all devices which do not match any of its filters. Each 503 filter can be partially specified, e.g. "/job:ps", "/job:worker/replica:3", 504 etc. Note that a device is always visible to the worker it is located on. 505 506 For example, to set the device filters for a parameter server cluster: 507 508 ```python 509 cdf = tf.config.experimental.ClusterDeviceFilters() 510 for i in range(num_workers): 511 cdf.set_device_filters('worker', i, ['/job:ps']) 512 for i in range(num_ps): 513 cdf.set_device_filters('ps', i, ['/job:worker']) 514 515 tf.config.experimental_connect_to_cluster(cluster_def, 516 cluster_device_filters=cdf) 517 ``` 518 519 The device filters can be partically specified. For remote tasks that do not 520 have device filters specified, all devices will be visible to them. 521 """ 522 523 def __init__(self): 524 # `_device_filters` is a dict mapping job names to job device filters. 525 # Job device filters further maps task IDs to task device filters. 526 # Task device filters are a list of strings, each one is a device filter. 527 self._device_filters = {} 528 529 # Serialized protobuf for cluster device filters. 530 self._cluster_device_filters = None 531 532 def set_device_filters(self, job_name, task_index, device_filters): 533 """Set the device filters for given job name and task id.""" 534 assert all(isinstance(df, str) for df in device_filters) 535 self._device_filters.setdefault(job_name, {}) 536 self._device_filters[job_name][task_index] = [df for df in device_filters] 537 # Due to updates in data, invalidate the serialized proto cache. 538 self._cluster_device_filters = None 539 540 def _as_cluster_device_filters(self): 541 """Returns a serialized protobuf of cluster device filters.""" 542 if self._cluster_device_filters: 543 return self._cluster_device_filters 544 545 self._make_cluster_device_filters() 546 return self._cluster_device_filters 547 548 def _make_cluster_device_filters(self): 549 """Creates `ClusterDeviceFilters` proto based on the `_device_filters`. 550 551 Raises: 552 TypeError: If `_device_filters` is not a dictionary mapping strings to 553 a map of task indices and device filters. 554 """ 555 self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters() 556 557 # Sort by job_name to produce deterministic protobufs. 558 for job_name, tasks in sorted(self._device_filters.items()): 559 try: 560 job_name = compat.as_bytes(job_name) 561 except TypeError: 562 raise TypeError("Job name %r must be bytes or unicode" % job_name) 563 564 jdf = self._cluster_device_filters.jobs.add() 565 jdf.name = job_name 566 567 for i, task_device_filters in sorted(tasks.items()): 568 for tdf in task_device_filters: 569 try: 570 tdf = compat.as_bytes(tdf) 571 except TypeError: 572 raise TypeError("Device filter %r must be bytes or unicode" % tdf) 573 jdf.tasks[i].device_filters.append(tdf) 574