# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import device_filters_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, or a `tf.train.ClusterSpec` object, describing the
      server to be defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its
      job. Defaults to the value in `server_or_cluster_def`, if specified.
      Otherwise defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
      in `server_or_cluster_def`, if specified. Otherwise defaults to
      `"grpc"`.
    config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
      configuration options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
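
  Example:
    A minimal sketch of the default inference: for a single-job, single-task
    cluster, `job_name` and `task_index` can be left as `None` and will be
    inferred, and `protocol` falls back to `"grpc"`.

    ```python
    cluster = ClusterSpec({"local": ["localhost:2222"]})
    server_def = _make_server_def(
        cluster, job_name=None, task_index=None, protocol=None, config=None)
    # server_def.job_name == "local", server_def.task_index == 0,
    # server_def.protocol == "grpc"
    ```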
  """
  server_def = tensorflow_server_pb2.ServerDef()
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  else:
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    if job_name is None:
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name,
        task_index=task_index,
        protocol=protocol)
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  return server_def


@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
@deprecation.deprecated_endpoints("train.Server")
class Server(object):
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.distribute.Server` instance encapsulates a set of devices and a
  `tf.compat.v1.Session` target that can participate in distributed training.
  A server belongs to a cluster (specified by a `tf.train.ClusterSpec`), and
  corresponds to a particular task in a named job. The server can communicate
  with any other server in the same cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
        server to be created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
        value in `server_or_cluster_def`, if specified. Otherwise defaults to
        `"grpc"`.
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
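
    Example:
      A minimal sketch: create (and start) the server for the first task of
      a two-task `"worker"` job running on the local machine.

      ```python
      cluster = tf.train.ClusterSpec({"worker": ["localhost:2222",
                                                 "localhost:2223"]})
      server = tf.distribute.Server(cluster, job_name="worker", task_index=0)
      ```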
    """
    self._server_def = _make_server_def(server_or_cluster_def, job_name,
                                        task_index, protocol, config)
    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
    if start:
      self.start()

  def __del__(self):
    # At shutdown, `errors` may have been garbage collected.
    if errors is not None:
      exception = errors.UnimplementedError
    else:
      exception = Exception
    try:
      c_api.TF_ServerStop(self._server)
      # Clean shutdown of servers is not yet implemented, so
      # we leak instead of calling c_api.TF_DeleteServer here.
      # See:
      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
    except AttributeError:
      # At shutdown, `c_api` may have been garbage collected.
      pass
    except exception:
      pass
    self._server = None

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    c_api.TF_ServerStart(self._server)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    c_api.TF_ServerJoin(self._server)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.compat.v1.Session` to connect to this server.

    To create a `tf.compat.v1.Session` that connects to this server, use the
    following snippet:

    ```python
    server = tf.distribute.Server(...)
    with tf.compat.v1.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return c_api.TF_ServerTarget(self._server)

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"localhost"`.

    Args:
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Returns:
      A local `tf.distribute.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"localhost": ["localhost:0"]},
                  protocol="grpc",
                  config=config,
                  start=start)


@tf_export("train.ClusterSpec")
class ClusterSpec(object):
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  `tf.distribute.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a list of
        network addresses, or (ii) a dictionary mapping integer task indices
        to network addresses; or a `tf.train.ClusterDef` protocol buffer.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf.
    """
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          job_tasks = {i: task for i, task in enumerate(tasks)}
        elif isinstance(tasks, dict):
          job_tasks = {i: task for i, task in tasks.items()}
        else:
          raise TypeError("The tasks for job %r must be a list or a "
                          "dictionary from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      self._cluster_def = cluster
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()
        }
    elif isinstance(cluster, ClusterSpec):
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()
        }
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  def __bool__(self):
    return bool(self._cluster_spec)

  # Python 2.x
  __nonzero__ = __bool__

  def __eq__(self, other):
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __repr__(self):
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
    ]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if len(task_indices) == 0:
        ret[job] = {}
        continue
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return list(sorted(job.keys()))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
        or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r" %
                       (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    ret = [None for _ in range(max(job.keys()) + 1)]
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to
        lists of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError("Task address %r must be bytes or unicode" %
                          task_address)
        job_def.tasks[i] = task_address


@tf_export("config.experimental.ClusterDeviceFilters")
class ClusterDeviceFilters(object):
  """Represents a collection of device filters for the remote workers in a cluster.

  NOTE: this is an experimental API and subject to changes.

  Sets device filters for selected jobs and tasks. For each remote worker,
  the device filters are a list of strings. When any filters are present,
  the remote worker will ignore all devices which do not match any of its
  filters. Each filter can be partially specified, e.g. "/job:ps",
  "/job:worker/replica:3", etc. Note that a device is always visible to the
  worker it is located on.

  For example, to set the device filters for a parameter server cluster:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  Device filters can be set for only part of the cluster: remote tasks with
  no device filters specified will have all devices visible to them.
  """

  def __init__(self):
    # `_device_filters` is a dict mapping job names to job device filters.
    # Job device filters further map task IDs to task device filters.
    # Task device filters are a list of strings, each one a device filter.
    self._device_filters = {}

    # Serialized protobuf for cluster device filters.
    self._cluster_device_filters = None

  def set_device_filters(self, job_name, task_index, device_filters):
    """Sets the device filters for the given job name and task index."""
    assert all(isinstance(df, str) for df in device_filters)
    self._device_filters.setdefault(job_name, {})
    self._device_filters[job_name][task_index] = [df for df in device_filters]
    # The data has changed, so invalidate the serialized proto cache.
    self._cluster_device_filters = None

  def _as_cluster_device_filters(self):
    """Returns a serialized protobuf of cluster device filters."""
    if self._cluster_device_filters:
      return self._cluster_device_filters

    self._make_cluster_device_filters()
    return self._cluster_device_filters

  def _make_cluster_device_filters(self):
    """Creates a `ClusterDeviceFilters` proto based on `_device_filters`.

    Raises:
      TypeError: If `_device_filters` is not a dictionary mapping strings to
        a map of task indices and device filters.
    """
    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()

    # Sort by job_name to produce deterministic protobufs.
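    # Each entry added to `jobs` below describes one job; its `tasks` map
    # associates a task index with that task's list of device filters.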
    for job_name, tasks in sorted(self._device_filters.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      jdf = self._cluster_device_filters.jobs.add()
      jdf.name = job_name

      for i, task_device_filters in sorted(tasks.items()):
        for tdf in task_device_filters:
          try:
            tdf = compat.as_bytes(tdf)
          except TypeError:
            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
          jdf.tasks[i].device_filters.append(tdf)