• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for multi-worker distribution strategies."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.python.distribute import distribute_coordinator_context as dc_context
23from tensorflow.python.training import server_lib
24
25
def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    return server_lib.ClusterSpec(cluster_spec)
  elif not isinstance(cluster_spec, server_lib.ClusterSpec):
    # Fixed mismatched quoting in the error message ("`cluster_spec'").
    raise ValueError(
        "`cluster_spec` should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return cluster_spec
47
48
49# TODO(yuefengz): add more validations.
def _validate_cluster_spec(cluster_spec, task_type, task_id):
  """Validates `cluster_spec`.

  It checks:
  0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
  1) task type is one of "chief", "worker", "evaluator" or "ps".
  2) whether there is such a task type as `task_type` in the `cluster_spec`. The
     only exception is `evaluator`. In other words, it is still a valid
     configuration when `task_type` is `evaluator` but it doesn't appear in
     `cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
  3) whether there is at most one "chief" job.
  4) whether there is at most one "evaluator" job.
  5) whether the `task_id` is smaller than the number of tasks for that
     particular `task_type`.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  if cluster_spec is None or task_type is None or task_id is None:
    raise ValueError(
        "None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")

  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
  if task_type not in ("chief", "worker", "evaluator", "ps"):
    raise ValueError(
        "Unrecognized task_type: %r, valid task types are: \"chief\", "
        "\"worker\", \"evaluator\" and \"ps\"." % task_type)

  if task_type and task_type not in cluster_spec and task_type != "evaluator":
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  if len(cluster_spec.get("chief", [])) > 1:
    raise ValueError("There must be at most one 'chief' job.")

  if len(cluster_spec.get("evaluator", [])) > 1:
    raise ValueError("There must be at most one 'evaluator' job.")

  # The `evaluator` job is allowed to be missing in `cluster_spec`.
  if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
95
96
def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is chief in the cluster.

  Since there is at most one evaluator and the evaluator itself should be
  independent of the training cluster, the evaluator job is also a chief job on
  its own.

  If this is currently running under a `_WorkerContext` of distribute
  coordinator, the arguments can be omitted as the result is already available.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
  if has_worker_context():
    # A worker context already knows the answer; trust it.
    return dc_context.get_current_worker_context().is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  # A "chief" task is chief by definition, and the lone "evaluator" acts as
  # the chief of its own one-task cluster.
  if task_type in ("chief", "evaluator"):
    return True

  # When no explicit chief exists, worker 0 serves as chief. This is common
  # in CollectiveAllReduceStrategy.
  return task_type == "worker" and task_id == 0 and "chief" not in cluster_dict
135
136
def collective_leader(cluster_spec, task_type, task_id):
  """Return the job name for the leader of for collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string indicating the leader job name or empty string if no need to set
    leader job.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # Local training needs no collective leader.
  if not cluster_spec.as_dict():
    return ""

  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # There is at most one evaluator and it runs on its own, so no leader.
  if task_type == "evaluator":
    return ""

  # Prefer the chief as leader; fall back to worker 0 when no chief job
  # exists. One of the two jobs must be present.
  assert "chief" in cluster_spec.jobs or "worker" in cluster_spec.jobs
  leader_job = "chief" if "chief" in cluster_spec.jobs else "worker"
  return "/job:%s/replica:0/task:0" % leader_job
169
170
def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster."""
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  # Jobs such as "ps" should never ask for a worker count.
  if task_type not in ("chief", "worker", "evaluator"):
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # The "evaluator" lives in its own cluster (or its own partition of one),
    # so "chief"/"worker" tasks are not counted for it.
    return len(cluster_dict["evaluator"])

  # For training tasks, the chief also acts as a worker, so count both jobs.
  chief_tasks = cluster_dict.get("chief", [])
  worker_tasks = cluster_dict.get("worker", [])
  return len(chief_tasks) + len(worker_tasks)
190
191
def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  It returns an id ranging from 0 to
  `worker_count(cluster_spec, task_type) - 1`, inclusive.

  Note: this function assumes that "evaluate" job is in its own cluster or its
  own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # The "chief" job has always id 0 and there is at most one and "worker" jobs
  # come after it.
  if task_type == "chief":
    return 0

  if task_type == "worker":
    return task_id + len(cluster_spec.get("chief", []))

  # The "evaluator" is in its own cluster or its own partition of a cluster.
  if task_type == "evaluator":
    return task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)
228
229
def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training, if saving checkpoint is requested by user, or needed
  for fault-tolerance, the cluster should save checkpoint but not necessarily
  every worker in the cluster should.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as there
  can be other files to save such as summary.

  Returns:
      Whether this particular worker in the cluster should save checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.should_checkpoint
244
245
def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training, if loading checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should load checkpoint but not
  necessarily every worker in the cluster should.

  Returns:
      Whether this particular worker in the cluster should load checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.experimental_should_init
257
258
def wait_for_other_workers():
  """Waits for other workers to reach the same call to this method."""
  worker_context = dc_context.get_current_worker_context()
  return worker_context.wait_for_other_workers()
262
263
def has_worker_context():
  """Returns whether a worker context has been entered."""
  current_context = dc_context.get_current_worker_context()
  return current_context is not None
267