• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A Python interface for creating TensorFlow servers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.core.protobuf import device_filters_pb2
23from tensorflow.core.protobuf import tensorflow_server_pb2
24from tensorflow.python.client import pywrap_tf_session as c_api
25from tensorflow.python.framework import errors
26from tensorflow.python.util import compat
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
29
30
31def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
32                     config):
33  """Creates a `tf.train.ServerDef` protocol buffer.
34
35  Args:
36    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
37      protocol buffer, or a `tf.train.ClusterSpec` object, describing the server
38      to be defined and/or the cluster of which it is a member.
39    job_name: (Optional.) Specifies the name of the job of which the server is a
40      member. Defaults to the value in `server_or_cluster_def`, if specified.
41    task_index: (Optional.) Specifies the task index of the server in its job.
42      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
43      defaults to 0 if the server's job has only one task.
44    protocol: (Optional.) Specifies the protocol to be used by the server.
45      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value in
46      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
47    config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default
48      configuration options for all sessions that run on this server.
49
50  Returns:
51    A `tf.train.ServerDef`.
52
53  Raises:
54    TypeError: If the arguments do not have the appropriate type.
55    ValueError: If an argument is not specified and cannot be inferred.
56  """
57  server_def = tensorflow_server_pb2.ServerDef()
58  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
59    server_def.MergeFrom(server_or_cluster_def)
60    if job_name is not None:
61      server_def.job_name = job_name
62    if task_index is not None:
63      server_def.task_index = task_index
64    if protocol is not None:
65      server_def.protocol = protocol
66    if config is not None:
67      server_def.default_session_config.MergeFrom(config)
68  else:
69    try:
70      cluster_spec = ClusterSpec(server_or_cluster_def)
71    except TypeError:
72      raise TypeError("Could not convert `server_or_cluster_def` to a "
73                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
74    if job_name is None:
75      if len(cluster_spec.jobs) == 1:
76        job_name = cluster_spec.jobs[0]
77      else:
78        raise ValueError("Must specify an explicit `job_name`.")
79    if task_index is None:
80      task_indices = cluster_spec.task_indices(job_name)
81      if len(task_indices) == 1:
82        task_index = task_indices[0]
83      else:
84        raise ValueError("Must specify an explicit `task_index`.")
85    if protocol is None:
86      protocol = "grpc"
87
88    server_def = tensorflow_server_pb2.ServerDef(
89        cluster=cluster_spec.as_cluster_def(),
90        job_name=job_name,
91        task_index=task_index,
92        protocol=protocol)
93    if config is not None:
94      server_def.default_session_config.MergeFrom(config)
95  return server_def
96
97
98@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
99@deprecation.deprecated_endpoints("train.Server")
100class Server(object):
101  """An in-process TensorFlow server, for use in distributed training.
102
103  A `tf.distribute.Server` instance encapsulates a set of devices and a
104  `tf.compat.v1.Session` target that
105  can participate in distributed training. A server belongs to a
106  cluster (specified by a `tf.train.ClusterSpec`), and
107  corresponds to a particular task in a named job. The server can
108  communicate with any other server in the same cluster.
109  """
110
111  def __init__(self,
112               server_or_cluster_def,
113               job_name=None,
114               task_index=None,
115               protocol=None,
116               config=None,
117               start=True):
118    """Creates a new server with the given definition.
119
120    The `job_name`, `task_index`, and `protocol` arguments are optional, and
121    override any information provided in `server_or_cluster_def`.
122
123    Args:
124      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
125        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
126        server to be created and/or the cluster of which it is a member.
127      job_name: (Optional.) Specifies the name of the job of which the server is
128        a member. Defaults to the value in `server_or_cluster_def`, if
129        specified.
130      task_index: (Optional.) Specifies the task index of the server in its job.
131        Defaults to the value in `server_or_cluster_def`, if specified.
132        Otherwise defaults to 0 if the server's job has only one task.
133      protocol: (Optional.) Specifies the protocol to be used by the server.
134        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
135        in `server_or_cluster_def`, if specified. Otherwise defaults to
136        `"grpc"`.
137      config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default
138        configuration options for all sessions that run on this server.
139      start: (Optional.) Boolean, indicating whether to start the server after
140        creating it. Defaults to `True`.
141
142    Raises:
143      tf.errors.OpError: Or one of its subclasses if an error occurs while
144        creating the TensorFlow server.
145    """
146    self._server_def = _make_server_def(server_or_cluster_def, job_name,
147                                        task_index, protocol, config)
148    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
149    if start:
150      self.start()
151
152  def __del__(self):
153    try:
154      c_api.TF_ServerStop(self._server)
155      # Clean shutdown of servers is not yet implemented, so
156      # we leak instead of calling c_api.TF_DeleteServer here.
157      # See:
158      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
159    except errors.UnimplementedError:
160      pass
161    except AttributeError:
162      # At shutdown, `c_api` may have been garbage collected.
163      pass
164    self._server = None
165
166  def start(self):
167    """Starts this server.
168
169    Raises:
170      tf.errors.OpError: Or one of its subclasses if an error occurs while
171        starting the TensorFlow server.
172    """
173    c_api.TF_ServerStart(self._server)
174
175  def join(self):
176    """Blocks until the server has shut down.
177
178    This method currently blocks forever.
179
180    Raises:
181      tf.errors.OpError: Or one of its subclasses if an error occurs while
182        joining the TensorFlow server.
183    """
184    c_api.TF_ServerJoin(self._server)
185
186  @property
187  def server_def(self):
188    """Returns the `tf.train.ServerDef` for this server.
189
190    Returns:
191      A `tf.train.ServerDef` protocol buffer that describes the configuration
192      of this server.
193    """
194    return self._server_def
195
196  @property
197  def target(self):
198    """Returns the target for a `tf.compat.v1.Session` to connect to this server.
199
200    To create a
201    `tf.compat.v1.Session` that
202    connects to this server, use the following snippet:
203
204    ```python
205    server = tf.distribute.Server(...)
206    with tf.compat.v1.Session(server.target):
207      # ...
208    ```
209
210    Returns:
211      A string containing a session target for this server.
212    """
213    return c_api.TF_ServerTarget(self._server)
214
215  @staticmethod
216  def create_local_server(config=None, start=True):
217    """Creates a new single-process cluster running on the local host.
218
219    This method is a convenience wrapper for creating a
220    `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
221    single-process cluster containing a single task in a job called
222    `"local"`.
223
224    Args:
225      config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default
226        configuration options for all sessions that run on this server.
227      start: (Optional.) Boolean, indicating whether to start the server after
228        creating it. Defaults to `True`.
229
230    Returns:
231      A local `tf.distribute.Server`.
232    """
233    # Specifying port 0 means that the OS will choose a free port for the
234    # server.
235    return Server({"localhost": ["localhost:0"]},
236                  protocol="grpc",
237                  config=config,
238                  start=start)
239
240
241@tf_export("train.ClusterSpec")
242class ClusterSpec(object):
243  """Represents a cluster as a set of "tasks", organized into "jobs".
244
245  A `tf.train.ClusterSpec` represents the set of processes that
246  participate in a distributed TensorFlow computation. Every
247  `tf.distribute.Server` is constructed in a particular cluster.
248
249  To create a cluster with two jobs and five tasks, you specify the
250  mapping from job names to lists of network addresses (typically
251  hostname-port pairs).
252
253  ```python
254  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
255                                             "worker1.example.com:2222",
256                                             "worker2.example.com:2222"],
257                                  "ps": ["ps0.example.com:2222",
258                                         "ps1.example.com:2222"]})
259  ```
260
261  Each job may also be specified as a sparse mapping from task indices
262  to network addresses. This enables a server to be configured without
263  needing to know the identity of (for example) all other worker
264  tasks:
265
266  ```python
267  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
268                                  "ps": ["ps0.example.com:2222",
269                                         "ps1.example.com:2222"]})
270  ```
271  """
272
273  def __init__(self, cluster):
274    """Creates a `ClusterSpec`.
275
276    Args:
277      cluster: A dictionary mapping one or more job names to (i) a list of
278        network addresses, or (ii) a dictionary mapping integer task indices to
279        network addresses; or a `tf.train.ClusterDef` protocol buffer.
280
281    Raises:
282      TypeError: If `cluster` is not a dictionary mapping strings to lists
283        of strings, and not a `tf.train.ClusterDef` protobuf.
284    """
285    if isinstance(cluster, dict):
286      self._cluster_spec = {}
287      for job_name, tasks in cluster.items():
288        if isinstance(tasks, (list, tuple)):
289          job_tasks = {i: task for i, task in enumerate(tasks)}
290        elif isinstance(tasks, dict):
291          job_tasks = {i: task for i, task in tasks.items()}
292        else:
293          raise TypeError("The tasks for job %r must be a list or a dictionary "
294                          "from integers to strings." % job_name)
295        self._cluster_spec[job_name] = job_tasks
296      self._make_cluster_def()
297    elif isinstance(cluster, cluster_pb2.ClusterDef):
298      self._cluster_def = cluster
299      self._cluster_spec = {}
300      for job_def in self._cluster_def.job:
301        self._cluster_spec[job_def.name] = {
302            i: t for i, t in job_def.tasks.items()
303        }
304    elif isinstance(cluster, ClusterSpec):
305      self._cluster_def = cluster_pb2.ClusterDef()
306      self._cluster_def.MergeFrom(cluster.as_cluster_def())
307      self._cluster_spec = {}
308      for job_def in self._cluster_def.job:
309        self._cluster_spec[job_def.name] = {
310            i: t for i, t in job_def.tasks.items()
311        }
312    else:
313      raise TypeError("`cluster` must be a dictionary mapping one or more "
314                      "job names to lists of network addresses, or a "
315                      "`ClusterDef` protocol buffer")
316
317  def __nonzero__(self):
318    return bool(self._cluster_spec)
319
320  # Python 3.x
321  __bool__ = __nonzero__
322
323  def __eq__(self, other):
324    return self._cluster_spec == other
325
326  def __ne__(self, other):
327    return self._cluster_spec != other
328
329  def __repr__(self):
330    key_values = self.as_dict()
331    string_items = [
332        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
333    ]
334    return "ClusterSpec({" + ", ".join(string_items) + "})"
335
336  def as_dict(self):
337    """Returns a dictionary from job names to their tasks.
338
339    For each job, if the task index space is dense, the corresponding
340    value will be a list of network addresses; otherwise it will be a
341    dictionary mapping (sparse) task indices to the corresponding
342    addresses.
343
344    Returns:
345      A dictionary mapping job names to lists or dictionaries
346      describing the tasks in those jobs.
347    """
348    ret = {}
349    for job in self.jobs:
350      task_indices = self.task_indices(job)
351      if len(task_indices) == 0:
352        ret[job] = {}
353        continue
354      if max(task_indices) + 1 == len(task_indices):
355        # Return a list because the task indices are dense. This
356        # matches the behavior of `as_dict()` before support for
357        # sparse jobs was added.
358        ret[job] = self.job_tasks(job)
359      else:
360        ret[job] = {i: self.task_address(job, i) for i in task_indices}
361    return ret
362
363  def as_cluster_def(self):
364    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
365    return self._cluster_def
366
367  @property
368  def jobs(self):
369    """Returns a list of job names in this cluster.
370
371    Returns:
372      A list of strings, corresponding to the names of jobs in this cluster.
373    """
374    return list(self._cluster_spec.keys())
375
376  def num_tasks(self, job_name):
377    """Returns the number of tasks defined in the given job.
378
379    Args:
380      job_name: The string name of a job in this cluster.
381
382    Returns:
383      The number of tasks defined in the given job.
384
385    Raises:
386      ValueError: If `job_name` does not name a job in this cluster.
387    """
388    try:
389      job = self._cluster_spec[job_name]
390    except KeyError:
391      raise ValueError("No such job in cluster: %r" % job_name)
392    return len(job)
393
394  def task_indices(self, job_name):
395    """Returns a list of valid task indices in the given job.
396
397    Args:
398      job_name: The string name of a job in this cluster.
399
400    Returns:
401      A list of valid task indices in the given job.
402
403    Raises:
404      ValueError: If `job_name` does not name a job in this cluster,
405      or no task with index `task_index` is defined in that job.
406    """
407    try:
408      job = self._cluster_spec[job_name]
409    except KeyError:
410      raise ValueError("No such job in cluster: %r" % job_name)
411    return list(sorted(job.keys()))
412
413  def task_address(self, job_name, task_index):
414    """Returns the address of the given task in the given job.
415
416    Args:
417      job_name: The string name of a job in this cluster.
418      task_index: A non-negative integer.
419
420    Returns:
421      The address of the given task in the given job.
422
423    Raises:
424      ValueError: If `job_name` does not name a job in this cluster,
425      or no task with index `task_index` is defined in that job.
426    """
427    try:
428      job = self._cluster_spec[job_name]
429    except KeyError:
430      raise ValueError("No such job in cluster: %r" % job_name)
431    try:
432      return job[task_index]
433    except KeyError:
434      raise ValueError("No task with index %r in job %r" %
435                       (task_index, job_name))
436
437  def job_tasks(self, job_name):
438    """Returns a mapping from task ID to address in the given job.
439
440    NOTE: For backwards compatibility, this method returns a list. If
441    the given job was defined with a sparse set of task indices, the
442    length of this list may not reflect the number of tasks defined in
443    this job. Use the `tf.train.ClusterSpec.num_tasks` method
444    to find the number of tasks defined in a particular job.
445
446    Args:
447      job_name: The string name of a job in this cluster.
448
449    Returns:
450      A list of task addresses, where the index in the list
451      corresponds to the task index of each task. The list may contain
452      `None` if the job was defined with a sparse set of task indices.
453
454    Raises:
455      ValueError: If `job_name` does not name a job in this cluster.
456    """
457    try:
458      job = self._cluster_spec[job_name]
459    except KeyError:
460      raise ValueError("No such job in cluster: %r" % job_name)
461    ret = [None for _ in range(max(job.keys()) + 1)]
462    for i, task in job.items():
463      ret[i] = task
464    return ret
465
466  def _make_cluster_def(self):
467    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.
468
469    Raises:
470      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
471        of strings.
472    """
473    self._cluster_def = cluster_pb2.ClusterDef()
474
475    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
476    for job_name, tasks in sorted(self._cluster_spec.items()):
477      try:
478        job_name = compat.as_bytes(job_name)
479      except TypeError:
480        raise TypeError("Job name %r must be bytes or unicode" % job_name)
481
482      job_def = self._cluster_def.job.add()
483      job_def.name = job_name
484
485      for i, task_address in sorted(tasks.items()):
486        try:
487          task_address = compat.as_bytes(task_address)
488        except TypeError:
489          raise TypeError("Task address %r must be bytes or unicode" %
490                          task_address)
491        job_def.tasks[i] = task_address
492
493
494@tf_export("config.experimental.ClusterDeviceFilters")
495class ClusterDeviceFilters(object):
496  """Represent a collection of device filters for the remote workers in cluster.
497
498  NOTE: this is an experimental API and subject to changes.
499
500  Set device filters for selective jobs and tasks. For each remote worker, the
501  device filters are a list of strings. When any filters are present, the remote
502  worker will ignore all devices which do not match any of its filters. Each
503  filter can be partially specified, e.g. "/job:ps", "/job:worker/replica:3",
504  etc. Note that a device is always visible to the worker it is located on.
505
506  For example, to set the device filters for a parameter server cluster:
507
508  ```python
509  cdf = tf.config.experimental.ClusterDeviceFilters()
510  for i in range(num_workers):
511    cdf.set_device_filters('worker', i, ['/job:ps'])
512  for i in range(num_ps):
513    cdf.set_device_filters('ps', i, ['/job:worker'])
514
515  tf.config.experimental_connect_to_cluster(cluster_def,
516                                            cluster_device_filters=cdf)
517  ```
518
519  The device filters can be partically specified. For remote tasks that do not
520  have device filters specified, all devices will be visible to them.
521  """
522
523  def __init__(self):
524    # `_device_filters` is a dict mapping job names to job device filters.
525    # Job device filters further maps task IDs to task device filters.
526    # Task device filters are a list of strings, each one is a device filter.
527    self._device_filters = {}
528
529    # Serialized protobuf for cluster device filters.
530    self._cluster_device_filters = None
531
532  def set_device_filters(self, job_name, task_index, device_filters):
533    """Set the device filters for given job name and task id."""
534    assert all(isinstance(df, str) for df in device_filters)
535    self._device_filters.setdefault(job_name, {})
536    self._device_filters[job_name][task_index] = [df for df in device_filters]
537    # Due to updates in data, invalidate the serialized proto cache.
538    self._cluster_device_filters = None
539
540  def _as_cluster_device_filters(self):
541    """Returns a serialized protobuf of cluster device filters."""
542    if self._cluster_device_filters:
543      return self._cluster_device_filters
544
545    self._make_cluster_device_filters()
546    return self._cluster_device_filters
547
548  def _make_cluster_device_filters(self):
549    """Creates `ClusterDeviceFilters` proto based on the `_device_filters`.
550
551    Raises:
552      TypeError: If `_device_filters` is not a dictionary mapping strings to
553      a map of task indices and device filters.
554    """
555    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()
556
557    # Sort by job_name to produce deterministic protobufs.
558    for job_name, tasks in sorted(self._device_filters.items()):
559      try:
560        job_name = compat.as_bytes(job_name)
561      except TypeError:
562        raise TypeError("Job name %r must be bytes or unicode" % job_name)
563
564      jdf = self._cluster_device_filters.jobs.add()
565      jdf.name = job_name
566
567      for i, task_device_filters in sorted(tasks.items()):
568        for tdf in task_device_filters:
569          try:
570            tdf = compat.as_bytes(tdf)
571          except TypeError:
572            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
573          jdf.tasks[i].device_filters.append(tdf)
574