# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import device_filters_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, or a `tf.train.ClusterSpec` object, describing the server
      to be defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server is a
      member. Defaults to the value in `server_or_cluster_def`, if specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
      defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value in
      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
      configuration options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
  server_def = tensorflow_server_pb2.ServerDef()
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  else:
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    if job_name is None:
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name,
        task_index=task_index,
        protocol=protocol)
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  return server_def


@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
@deprecation.deprecated_endpoints("train.Server")
class Server:
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.distribute.Server` instance encapsulates a set of devices and a
  `tf.compat.v1.Session` target that can participate in distributed
  training. A server belongs to a cluster (specified by a
  `tf.train.ClusterSpec`), and corresponds to a particular task in a
  named job. The server can communicate with any other server in the
  same cluster.
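
  For example, a minimal sketch of bringing up the first worker of a
  two-task cluster (the addresses are illustrative):

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["localhost:2222",
                                             "localhost:2223"]})
  server = tf.distribute.Server(cluster, job_name="worker", task_index=0)
  server.join()  # Blocks until the server shuts down (currently forever).
  ```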
105  """
106
107  def __init__(self,
108               server_or_cluster_def,
109               job_name=None,
110               task_index=None,
111               protocol=None,
112               config=None,
113               start=True):
114    """Creates a new server with the given definition.
115
116    The `job_name`, `task_index`, and `protocol` arguments are optional, and
117    override any information provided in `server_or_cluster_def`.
118
119    Args:
120      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
121        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
122        server to be created and/or the cluster of which it is a member.
123      job_name: (Optional.) Specifies the name of the job of which the server is
124        a member. Defaults to the value in `server_or_cluster_def`, if
125        specified.
126      task_index: (Optional.) Specifies the task index of the server in its job.
127        Defaults to the value in `server_or_cluster_def`, if specified.
128        Otherwise defaults to 0 if the server's job has only one task.
129      protocol: (Optional.) Specifies the protocol to be used by the server.
130        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
131        in `server_or_cluster_def`, if specified. Otherwise defaults to
132        `"grpc"`.
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    self._server_def = _make_server_def(server_or_cluster_def, job_name,
                                        task_index, protocol, config)
    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
    if start:
      self.start()

  def __del__(self):
    # At shutdown, `errors` may have been garbage collected.
    if errors is not None:
      exception = errors.UnimplementedError
    else:
      exception = Exception
    try:
      c_api.TF_ServerStop(self._server)
      # Clean shutdown of servers is not yet implemented, so
      # we leak instead of calling c_api.TF_DeleteServer here.
      # See:
      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
    except AttributeError:
      # At shutdown, `c_api` may have been garbage collected.
      pass
    except exception:
      pass
    self._server = None

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    c_api.TF_ServerStart(self._server)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    c_api.TF_ServerJoin(self._server)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.compat.v1.Session` to connect to this server.

    To create a `tf.compat.v1.Session` that connects to this server, use the
    following snippet:

    ```python
    server = tf.distribute.Server(...)
    with tf.compat.v1.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return c_api.TF_ServerTarget(self._server)

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"local"`.

    Args:
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.distribute.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"localhost": ["localhost:0"]},
                  protocol="grpc",
                  config=config,
                  start=start)


@tf_export("train.ClusterSpec")
class ClusterSpec:
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  `tf.distribute.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a list of
        network addresses, or (ii) a dictionary mapping integer task indices to
        network addresses; or a `tf.train.ClusterDef` protocol buffer; or
        another `tf.train.ClusterSpec`.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        or dictionaries of strings, and not a `tf.train.ClusterDef` protobuf
        or a `tf.train.ClusterSpec`.
285    """
286    if isinstance(cluster, dict):
287      self._cluster_spec = {}
288      for job_name, tasks in cluster.items():
289        if isinstance(tasks, (list, tuple)):
290          job_tasks = {i: task for i, task in enumerate(tasks)}
291        elif isinstance(tasks, dict):
292          job_tasks = {int(i): task for i, task in tasks.items()}
293        else:
294          raise TypeError("The tasks for job %r must be a list or a dictionary "
295                          "from integers to strings." % job_name)
296        self._cluster_spec[job_name] = job_tasks
297      self._make_cluster_def()
298    elif isinstance(cluster, cluster_pb2.ClusterDef):
299      self._cluster_def = cluster
300      self._cluster_spec = {}
301      for job_def in self._cluster_def.job:
302        self._cluster_spec[job_def.name] = {
303            i: t for i, t in job_def.tasks.items()
304        }
305    elif isinstance(cluster, ClusterSpec):
306      self._cluster_def = cluster_pb2.ClusterDef()
307      self._cluster_def.MergeFrom(cluster.as_cluster_def())
308      self._cluster_spec = {}
309      for job_def in self._cluster_def.job:
310        self._cluster_spec[job_def.name] = {
311            i: t for i, t in job_def.tasks.items()
312        }
313    else:
314      raise TypeError("`cluster` must be a dictionary mapping one or more "
315                      "job names to lists of network addresses, or a "
316                      "`ClusterDef` protocol buffer")
317
318  def __bool__(self):
319    return bool(self._cluster_spec)
320
321  # Python 2.x
322  __nonzero__ = __bool__
323
324  def __eq__(self, other):
325    return self._cluster_spec == other
326
327  def __ne__(self, other):
328    return self._cluster_spec != other
329
330  def __repr__(self):
331    key_values = self.as_dict()
332    string_items = [
333        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
334    ]
335    return "ClusterSpec({" + ", ".join(string_items) + "})"
336
337  def as_dict(self):
338    """Returns a dictionary from job names to their tasks.
339
340    For each job, if the task index space is dense, the corresponding
341    value will be a list of network addresses; otherwise it will be a
342    dictionary mapping (sparse) task indices to the corresponding
343    addresses.
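
    For example, a sketch of the dense vs. sparse cases (addresses are
    illustrative):

    ```python
    tf.train.ClusterSpec({"ps": ["ps0:2222"]}).as_dict()
    # ==> {"ps": ["ps0:2222"]}
    tf.train.ClusterSpec({"worker": {1: "worker1:2222"}}).as_dict()
    # ==> {"worker": {1: "worker1:2222"}}
    ```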

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if len(task_indices) == 0:
        ret[job] = {}
        continue
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

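    For example, with a sparsely-specified job (a sketch; the address is
    illustrative), `num_tasks` counts the tasks that are defined, not the
    maximum task index plus one:

    ```python
    cluster = tf.train.ClusterSpec({"worker": {7: "worker7.example.com:2222"}})
    cluster.num_tasks("worker")  # ==> 1
    ```
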
    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
407    """
408    try:
409      job = self._cluster_spec[job_name]
410    except KeyError:
411      raise ValueError("No such job in cluster: %r" % job_name)
412    return list(sorted(job.keys()))
413
414  def task_address(self, job_name, task_index):
415    """Returns the address of the given task in the given job.
416
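    For example (a sketch; the address is illustrative):

    ```python
    cluster = tf.train.ClusterSpec({"ps": ["ps0.example.com:2222"]})
    cluster.task_address("ps", 0)  # ==> "ps0.example.com:2222"
    ```
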
    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
      or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r" %
                       (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.
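
    For example, a sketch with a sparse job (the address is illustrative):

    ```python
    cluster = tf.train.ClusterSpec({"worker": {2: "worker2.example.com:2222"}})
    cluster.job_tasks("worker")
    # ==> [None, None, "worker2.example.com:2222"]
    ```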

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    ret = [None for _ in range(max(job.keys()) + 1)]
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError("Task address %r must be bytes or unicode" %
                          task_address)
        job_def.tasks[i] = task_address


@tf_export("config.experimental.ClusterDeviceFilters")
class ClusterDeviceFilters:
497  """Represent a collection of device filters for the remote workers in cluster.

  NOTE: this is an experimental API and subject to changes.

  Set device filters for selective jobs and tasks. For each remote worker, the
  device filters are a list of strings. When any filters are present, the remote
  worker will ignore all devices which do not match any of its filters. Each
  filter can be partially specified, e.g. "/job:ps", "/job:worker/replica:3",
  etc. Note that a device is always visible to the worker it is located on.

  For example, to set the device filters for a parameter server cluster:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  The device filters can be partially specified. For remote tasks that do not
  have device filters specified, all devices will be visible to them.
522  """
523
524  def __init__(self):
    # `_device_filters` is a dict mapping job names to job device filters.
    # A job device filter further maps task IDs to task device filters.
    # A task device filter is a list of strings, each of which is one filter.
    self._device_filters = {}

    # Serialized protobuf for cluster device filters.
    self._cluster_device_filters = None

  def set_device_filters(self, job_name, task_index, device_filters):
534    """Set the device filters for given job name and task id."""
    assert all(isinstance(df, str) for df in device_filters)
    self._device_filters.setdefault(job_name, {})
    self._device_filters[job_name][task_index] = [df for df in device_filters]
    # Due to updates in data, invalidate the serialized proto cache.
    self._cluster_device_filters = None

  def _as_cluster_device_filters(self):
    """Returns a serialized protobuf of cluster device filters."""
    if self._cluster_device_filters:
      return self._cluster_device_filters

    self._make_cluster_device_filters()
    return self._cluster_device_filters

  def _make_cluster_device_filters(self):
    """Creates `ClusterDeviceFilters` proto based on the `_device_filters`.

    Raises:
      TypeError: If `_device_filters` is not a dictionary mapping strings to
      a map of task indices and device filters.
    """
    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()

    # Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._device_filters.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      jdf = self._cluster_device_filters.jobs.add()
      jdf.name = job_name

      for i, task_device_filters in sorted(tasks.items()):
        for tdf in task_device_filters:
          try:
            tdf = compat.as_bytes(tdf)
          except TypeError:
            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
          jdf.tasks[i].device_filters.append(tdf)