• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for multi-worker distribution strategies."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.python.training import server_lib
23
24
def normalize_cluster_spec(cluster_spec):
  """Converts `cluster_spec` to a `ClusterSpec` object if it is not one yet.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object describing the
      cluster configuration.

  Returns:
    a `ClusterSpec` object. A dict or `ClusterDef` is wrapped in a newly
    constructed `ClusterSpec`; an existing `ClusterSpec` is returned as is.

  Raises:
    ValueError: if `cluster_spec` is neither a dict, a `ClusterDef` nor a
      `ClusterSpec`.
  """
  convertible_types = (dict, cluster_pb2.ClusterDef)
  if isinstance(cluster_spec, convertible_types):
    return server_lib.ClusterSpec(cluster_spec)

  # Already in the normalized form: hand the very same object back.
  if isinstance(cluster_spec, server_lib.ClusterSpec):
    return cluster_spec

  raise ValueError(
      "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
      "`tf.train.ClusterDef` object")
46
47
48# TODO(yuefengz): add more validations.
def _validate_cluster_spec(cluster_spec, task_type, task_id):
  """Validates `cluster_spec` for the given task.

  It checks
  1) whether there is such a task type as `task_type` in the
  `cluster_spec`.
  2) whether there is at most one "chief" job.
  3) whether the `task_id` is smaller than the number of `task_type`.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check. This includes a `task_type`
      of `None` or any other value that is not a job name in `cluster_spec`.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
  # The membership test is applied unconditionally. Skipping it for a falsy
  # `task_type` (e.g. `None`) would let the `cluster_spec[task_type]` lookup
  # below fail with a `KeyError` rather than the documented `ValueError`.
  if task_type not in cluster_spec:
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
  if len(cluster_spec.get("chief", [])) > 1:
    raise ValueError("There must be at most one 'chief' job.")
  if task_id >= len(cluster_spec[task_type]):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
73
74
def is_chief(cluster_spec, task_type, task_id):
  """Tells whether the given task acts as the chief of the cluster.

  A task is the chief if its type is "chief", or if the cluster has no "chief"
  job and the task is worker 0.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  jobs = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type == "chief":
    return True

  # Without a dedicated "chief" job, the first worker takes over the chief
  # role. This is common in CollectiveAllReduceStrategy.
  has_dedicated_chief = "chief" in jobs
  is_first_worker = task_type == "worker" and task_id == 0
  return bool(is_first_worker and not has_dedicated_chief)
102
103
def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster of the given `task_type`.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type of the caller; one of "chief", "worker" or
      "evaluator".

  Returns:
    For "evaluator", the number of "evaluator" tasks. For "chief" or "worker",
    the combined number of "chief" and "worker" tasks.

  Raises:
    ValueError: if `cluster_spec` fails validation for `task_type`, or if
      `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  jobs = normalize_cluster_spec(cluster_spec).as_dict()

  # The "evaluator" is in its own cluster or its own partition of a cluster,
  # so "chief" and "worker" tasks are not counted for it.
  if task_type == "evaluator":
    return len(jobs["evaluator"])

  # The "chief" is also a worker, so both jobs contribute to the total.
  if task_type in ("chief", "worker"):
    return sum(len(jobs.get(job, [])) for job in ("chief", "worker"))

  # Other jobs such as "ps" shouldn't call this function.
  raise ValueError("Unexpected `task_type` %r" % task_type)
123
124
def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task within the `task_type`'s cluster.

  The "chief" always gets id 0, "worker" ids are shifted by the number of
  "chief" tasks, and "evaluator" ids are the `task_id` itself. For a
  non-negative `task_id` that passes validation, the result lies in
  [0, `worker_count(cluster_spec, task_type)`).

  Note: this function assumes that the "evaluator" job is in its own cluster or
  its own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `cluster_spec` fails validation, or if `task_type` is not
      "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  jobs = normalize_cluster_spec(cluster_spec).as_dict()

  # There is at most one "chief" and it always comes first.
  if task_type == "chief":
    return 0

  # The "evaluator" is in its own cluster or its own partition of a cluster,
  # so its ids are not offset by any other job.
  if task_type == "evaluator":
    return task_id

  # "worker" tasks are numbered right after the "chief", if there is one.
  if task_type == "worker":
    num_chiefs = len(jobs.get("chief", []))
    return task_id + num_chiefs

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)
161