# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, or a `tf.train.ClusterSpec` object, describing the
      server to be defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
      defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
      in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.ConfigProto` that specifies default configuration
      options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
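
  For example (a minimal sketch; the address below is an illustrative
  placeholder), when the cluster defines a single job with a single task, the
  job name, task index, and protocol can all be inferred:

  ```python
  cluster = ClusterSpec({"local": ["localhost:2222"]})
  server_def = _make_server_def(cluster, job_name=None, task_index=None,
                                protocol=None, config=None)
  # server_def.job_name == "local", server_def.task_index == 0, and
  # server_def.protocol == "grpc".
  ```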
57 """ 58 server_def = tensorflow_server_pb2.ServerDef() 59 if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef): 60 server_def.MergeFrom(server_or_cluster_def) 61 if job_name is not None: 62 server_def.job_name = job_name 63 if task_index is not None: 64 server_def.task_index = task_index 65 if protocol is not None: 66 server_def.protocol = protocol 67 if config is not None: 68 server_def.default_session_config.MergeFrom(config) 69 else: 70 try: 71 cluster_spec = ClusterSpec(server_or_cluster_def) 72 except TypeError: 73 raise TypeError("Could not convert `server_or_cluster_def` to a " 74 "`tf.train.ServerDef` or `tf.train.ClusterSpec`.") 75 if job_name is None: 76 if len(cluster_spec.jobs) == 1: 77 job_name = cluster_spec.jobs[0] 78 else: 79 raise ValueError("Must specify an explicit `job_name`.") 80 if task_index is None: 81 task_indices = cluster_spec.task_indices(job_name) 82 if len(task_indices) == 1: 83 task_index = task_indices[0] 84 else: 85 raise ValueError("Must specify an explicit `task_index`.") 86 if protocol is None: 87 protocol = "grpc" 88 89 server_def = tensorflow_server_pb2.ServerDef( 90 cluster=cluster_spec.as_cluster_def(), 91 job_name=job_name, task_index=task_index, protocol=protocol) 92 if config is not None: 93 server_def.default_session_config.MergeFrom(config) 94 return server_def 95 96 97@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"]) 98@deprecation.deprecated_endpoints("train.Server") 99class Server(object): 100 """An in-process TensorFlow server, for use in distributed training. 101 102 A `tf.train.Server` instance encapsulates a set of devices and a 103 `tf.Session` target that 104 can participate in distributed training. A server belongs to a 105 cluster (specified by a `tf.train.ClusterSpec`), and 106 corresponds to a particular task in a named job. The server can 107 communicate with any other server in the same cluster. 108 """ 109 110 def __init__(self, 111 server_or_cluster_def, 112 job_name=None, 113 task_index=None, 114 protocol=None, 115 config=None, 116 start=True): 117 """Creates a new server with the given definition. 118 119 The `job_name`, `task_index`, and `protocol` arguments are optional, and 120 override any information provided in `server_or_cluster_def`. 121 122 Args: 123 server_or_cluster_def: A `tf.train.ServerDef` or 124 `tf.train.ClusterDef` protocol buffer, or a 125 `tf.train.ClusterSpec` object, describing the server to be 126 created and/or the cluster of which it is a member. 127 job_name: (Optional.) Specifies the name of the job of which the server 128 is a member. Defaults to the value in `server_or_cluster_def`, if 129 specified. 130 task_index: (Optional.) Specifies the task index of the server in its 131 job. Defaults to the value in `server_or_cluster_def`, if specified. 132 Otherwise defaults to 0 if the server's job has only one task. 133 protocol: (Optional.) Specifies the protocol to be used by the server. 134 Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the 135 value in `server_or_cluster_def`, if specified. Otherwise defaults to 136 `"grpc"`. 137 config: (Options.) A `tf.ConfigProto` that specifies default 138 configuration options for all sessions that run on this server. 139 start: (Optional.) Boolean, indicating whether to start the server 140 after creating it. Defaults to `True`. 141 142 Raises: 143 tf.errors.OpError: Or one of its subclasses if an error occurs while 144 creating the TensorFlow server. 
145 """ 146 self._server_def = _make_server_def(server_or_cluster_def, 147 job_name, task_index, protocol, config) 148 self._server = c_api.TF_NewServer(self._server_def.SerializeToString()) 149 if start: 150 self.start() 151 152 def __del__(self): 153 try: 154 c_api.TF_ServerStop(self._server) 155 # Clean shutdown of servers is not yet implemented, so 156 # we leak instead of calling c_api.TF_DeleteServer here. 157 # See: 158 # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73 159 except errors.UnimplementedError: 160 pass 161 except AttributeError: 162 # At shutdown, `c_api` may have been garbage collected. 163 pass 164 self._server = None 165 166 def start(self): 167 """Starts this server. 168 169 Raises: 170 tf.errors.OpError: Or one of its subclasses if an error occurs while 171 starting the TensorFlow server. 172 """ 173 c_api.TF_ServerStart(self._server) 174 175 def join(self): 176 """Blocks until the server has shut down. 177 178 This method currently blocks forever. 179 180 Raises: 181 tf.errors.OpError: Or one of its subclasses if an error occurs while 182 joining the TensorFlow server. 183 """ 184 c_api.TF_ServerJoin(self._server) 185 186 @property 187 def server_def(self): 188 """Returns the `tf.train.ServerDef` for this server. 189 190 Returns: 191 A `tf.train.ServerDef` protocol buffer that describes the configuration 192 of this server. 193 """ 194 return self._server_def 195 196 @property 197 def target(self): 198 """Returns the target for a `tf.Session` to connect to this server. 199 200 To create a 201 `tf.Session` that 202 connects to this server, use the following snippet: 203 204 ```python 205 server = tf.train.Server(...) 206 with tf.Session(server.target): 207 # ... 208 ``` 209 210 Returns: 211 A string containing a session target for this server. 212 """ 213 return c_api.TF_ServerTarget(self._server) 214 215 @staticmethod 216 def create_local_server(config=None, start=True): 217 """Creates a new single-process cluster running on the local host. 218 219 This method is a convenience wrapper for creating a 220 `tf.train.Server` with a `tf.train.ServerDef` that specifies a 221 single-process cluster containing a single task in a job called 222 `"local"`. 223 224 Args: 225 config: (Options.) A `tf.ConfigProto` that specifies default 226 configuration options for all sessions that run on this server. 227 start: (Optional.) Boolean, indicating whether to start the server after 228 creating it. Defaults to `True`. 229 230 Returns: 231 A local `tf.train.Server`. 232 """ 233 # Specifying port 0 means that the OS will choose a free port for the 234 # server. 235 return Server({"local": ["localhost:0"]}, protocol="grpc", config=config, 236 start=start) 237 238 239@tf_export("train.ClusterSpec") 240class ClusterSpec(object): 241 """Represents a cluster as a set of "tasks", organized into "jobs". 242 243 A `tf.train.ClusterSpec` represents the set of processes that 244 participate in a distributed TensorFlow computation. Every 245 `tf.train.Server` is constructed in a particular cluster. 246 247 To create a cluster with two jobs and five tasks, you specify the 248 mapping from job names to lists of network addresses (typically 249 hostname-port pairs). 

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a
        list of network addresses, or (ii) a dictionary mapping integer
        task indices to network addresses; or a `tf.train.ClusterDef`
        protocol buffer.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf.
    """
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          job_tasks = {i: task for i, task in enumerate(tasks)}
        elif isinstance(tasks, dict):
          job_tasks = {i: task for i, task in tasks.items()}
        else:
          raise TypeError("The tasks for job %r must be a list or a dictionary "
                          "from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      self._cluster_def = cluster
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()}
    elif isinstance(cluster, ClusterSpec):
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()}
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  def __nonzero__(self):
    return bool(self._cluster_spec)

  # Python 3.x
  __bool__ = __nonzero__

  def __eq__(self, other):
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __str__(self):
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if len(task_indices) == 0:
        ret[job] = {}
        continue
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return list(sorted(job.keys()))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
        or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r"
                       % (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    ret = [None for _ in range(max(job.keys()) + 1)]
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError(
              "Task address %r must be bytes or unicode" % task_address)
        job_def.tasks[i] = task_address
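

# Example (an illustrative sketch, not part of the library API): `ClusterSpec`
# also accepts sparse job definitions, and its accessors reflect the sparse
# task indices as follows. The address below is a placeholder.
#
#   spec = ClusterSpec({"worker": {1: "worker1.example.com:2222"}})
#   spec.task_indices("worker")  # ==> [1]
#   spec.num_tasks("worker")     # ==> 1
#   spec.job_tasks("worker")     # ==> [None, "worker1.example.com:2222"]
#   spec.as_dict()               # ==> {"worker": {1: "worker1.example.com:2222"}}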