# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import device_filters_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, or a `tf.train.ClusterSpec` object, describing the
      server to be defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its
      job. Defaults to the value in `server_or_cluster_def`, if specified.
      Otherwise defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
      in `server_or_cluster_def`, if specified. Otherwise defaults to
      `"grpc"`.
    config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
      configuration options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
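
  For example (an illustrative sketch; the address is a placeholder):

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["localhost:2222"]})
  # `job_name` and `task_index` can be inferred because the cluster has a
  # single job with a single task; `protocol` defaults to "grpc".
  server_def = _make_server_def(cluster, job_name=None, task_index=None,
                                protocol=None, config=None)
  ```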
  """
  server_def = tensorflow_server_pb2.ServerDef()
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  else:
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    if job_name is None:
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name,
        task_index=task_index,
        protocol=protocol)
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  return server_def


@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
@deprecation.deprecated_endpoints("train.Server")
class Server:
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.distribute.Server` instance encapsulates a set of devices and a
  `tf.compat.v1.Session` target that can participate in distributed training.
  A server belongs to a cluster (specified by a `tf.train.ClusterSpec`), and
  corresponds to a particular task in a named job. The server can communicate
  with any other server in the same cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
        server to be created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
        value in `server_or_cluster_def`, if specified. Otherwise defaults to
        `"grpc"`.
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
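
    Example (an illustrative sketch; addresses and ports are placeholders):

    ```python
    cluster = tf.train.ClusterSpec(
        {"worker": ["localhost:2222", "localhost:2223"]})
    # Create and start the server for task 0 of the "worker" job.
    server = tf.distribute.Server(cluster, job_name="worker", task_index=0)
    ```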
    """
    self._server_def = _make_server_def(server_or_cluster_def, job_name,
                                        task_index, protocol, config)
    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
    if start:
      self.start()

  def __del__(self):
    # At shutdown, `errors` may have been garbage collected.
    if errors is not None:
      exception = errors.UnimplementedError
    else:
      exception = Exception
    try:
      c_api.TF_ServerStop(self._server)
      # Clean shutdown of servers is not yet implemented, so
      # we leak instead of calling c_api.TF_DeleteServer here.
      # See:
      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
    except AttributeError:
      # At shutdown, `c_api` may have been garbage collected.
      pass
    except exception:
      pass
    self._server = None

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    c_api.TF_ServerStart(self._server)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    c_api.TF_ServerJoin(self._server)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.compat.v1.Session` to connect to this server.

    To create a `tf.compat.v1.Session` that connects to this server, use the
    following snippet:

    ```python
    server = tf.distribute.Server(...)
    with tf.compat.v1.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return c_api.TF_ServerTarget(self._server)

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"localhost"`.

    Args:
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Returns:
      A local `tf.distribute.Server`.
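
    Example (illustrative):

    ```python
    server = tf.distribute.Server.create_local_server()
    with tf.compat.v1.Session(server.target) as sess:
      # Run ops in a single-process cluster on a free local port.
      ...
    ```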
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"localhost": ["localhost:0"]},
                  protocol="grpc",
                  config=config,
                  start=start)


@tf_export("train.ClusterSpec")
class ClusterSpec:
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  `tf.distribute.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a list of
        network addresses, or (ii) a dictionary mapping integer task indices
        to network addresses; or a `tf.train.ClusterDef` protocol buffer or
        `tf.train.ClusterSpec` object.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf.
    """
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          job_tasks = {i: task for i, task in enumerate(tasks)}
        elif isinstance(tasks, dict):
          job_tasks = {int(i): task for i, task in tasks.items()}
        else:
          raise TypeError("The tasks for job %r must be a list or a "
                          "dictionary from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      self._cluster_def = cluster
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()
        }
    elif isinstance(cluster, ClusterSpec):
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()
        }
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  def __bool__(self):
    return bool(self._cluster_spec)

  # Python 2.x
  __nonzero__ = __bool__

  def __eq__(self, other):
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __repr__(self):
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
    ]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
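
    For example (illustrative):

    ```python
    cluster = tf.train.ClusterSpec(
        {"worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
         "ps": {1: "ps1.example.com:2222"}})
    cluster.as_dict()
    # ==> {'worker': ['worker0.example.com:2222',
    #                 'worker1.example.com:2222'],
    #      'ps': {1: 'ps1.example.com:2222'}}
    ```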
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if len(task_indices) == 0:
        ret[job] = {}
        continue
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return list(sorted(job.keys()))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
        or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r" %
                       (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
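
    For example (illustrative):

    ```python
    cluster = tf.train.ClusterSpec(
        {"worker": {0: "worker0.example.com:2222",
                    2: "worker2.example.com:2222"}})
    cluster.job_tasks("worker")
    # ==> ['worker0.example.com:2222', None, 'worker2.example.com:2222']
    ```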
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    ret = [None for _ in range(max(job.keys()) + 1)]
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to
        lists of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError("Task address %r must be bytes or unicode" %
                          task_address)
        job_def.tasks[i] = task_address


@tf_export("config.experimental.ClusterDeviceFilters")
class ClusterDeviceFilters:
  """Represents a collection of device filters for remote workers in a cluster.

  NOTE: this is an experimental API and subject to changes.

  Set device filters for selective jobs and tasks. For each remote worker,
  the device filters are a list of strings. When any filters are present, the
  remote worker will ignore all devices which do not match any of its
  filters. Each filter can be partially specified, e.g. "/job:ps",
  "/job:worker/replica:3", etc. Note that a device is always visible to the
  worker it is located on.

  For example, to set the device filters for a parameter server cluster:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  The device filters can be partially specified. For remote tasks that do not
  have device filters specified, all devices will be visible to them.
  """

  def __init__(self):
    # `_device_filters` is a dict mapping job names to job device filters.
    # Job device filters further map task IDs to task device filters.
    # Task device filters are a list of strings, where each string is a
    # device filter.
    self._device_filters = {}

    # Serialized protobuf for cluster device filters.
    self._cluster_device_filters = None

  def set_device_filters(self, job_name, task_index, device_filters):
    """Sets the device filters for the given job name and task index."""
    assert all(isinstance(df, str) for df in device_filters)
    self._device_filters.setdefault(job_name, {})
    self._device_filters[job_name][task_index] = [df for df in device_filters]
    # The data has changed, so invalidate the serialized proto cache.
    self._cluster_device_filters = None

  def _as_cluster_device_filters(self):
    """Returns a serialized protobuf of cluster device filters."""
    if self._cluster_device_filters:
      return self._cluster_device_filters

    self._make_cluster_device_filters()
    return self._cluster_device_filters

  def _make_cluster_device_filters(self):
    """Creates a `ClusterDeviceFilters` proto based on `_device_filters`.

    Raises:
      TypeError: If `_device_filters` is not a dictionary mapping strings to
        a map of task indices and device filters.
    """
    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()

    # Sort by job_name to produce deterministic protobufs.
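    # For example (illustrative): `{"worker": {0: ["/job:ps"]}}` becomes a
    # `ClusterDeviceFilters` proto containing one `JobDeviceFilters` named
    # "worker", whose task 0 carries the single device filter "/job:ps".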
    for job_name, tasks in sorted(self._device_filters.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      jdf = self._cluster_device_filters.jobs.add()
      jdf.name = job_name

      for i, task_device_filters in sorted(tasks.items()):
        for tdf in task_device_filters:
          try:
            tdf = compat.as_bytes(tdf)
          except TypeError:
            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
          jdf.tasks[i].device_filters.append(tdf)