# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for multi-worker distribution strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.training import server_lib


def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    return server_lib.ClusterSpec(cluster_spec)
  elif not isinstance(cluster_spec, server_lib.ClusterSpec):
    raise ValueError(
        "`cluster_spec` should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return cluster_spec


# TODO(yuefengz): add more validations.
def _validate_cluster_spec(cluster_spec, task_type, task_id):
  """Validates `cluster_spec`.

  It checks:
  0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
  1) task type is one of "chief", "worker", "evaluator" or "ps".
  2) whether there is such a task type as `task_type` in the `cluster_spec`. The
     only exception is `evaluator`. In other words, it is still a valid
     configuration when `task_type` is `evaluator` but it doesn't appear in
     `cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
  3) whether there is at most one "chief" job.
  4) whether there is at most one "evaluator" job.
  5) whether the `task_id` is smaller than the number of tasks for that
     particular `task_type`.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  if cluster_spec is None or task_type is None or task_id is None:
    raise ValueError(
        "None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")

  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
  if task_type not in ("chief", "worker", "evaluator", "ps"):
    raise ValueError(
        "Unrecognized task_type: %r, valid task types are: \"chief\", "
        "\"worker\", \"evaluator\" and \"ps\"." % task_type)

  if task_type and task_type not in cluster_spec and task_type != "evaluator":
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  if len(cluster_spec.get("chief", [])) > 1:
    raise ValueError("There must be at most one 'chief' job.")

  if len(cluster_spec.get("evaluator", [])) > 1:
    raise ValueError("There must be at most one 'evaluator' job.")

  # The `evaluator` job is allowed to be missing in `cluster_spec` (see check
  # (2) above), so only bounds-check `task_id` when the job is present.
  if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))


def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is chief in the cluster.

  Since there is at most one evaluator and the evaluator itself should be
  independent of the training cluster, the evaluator job is also a chief job on
  its own.

  If this is currently running under a `_WorkerContext` of distribute
  coordinator, the arguments can be omitted as the result is already available.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
  if has_worker_context():
    # If a worker context exists, use the value provided by it.
    return dc_context.get_current_worker_context().is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type == "chief" or task_type == "evaluator":
    return True

  # If chief not in the cluster_spec, use the first worker as chief. This is
  # common in CollectiveAllReduceStrategy.
  if ("chief" not in cluster_spec and task_type == "worker" and task_id == 0):
    return True
  return False


def collective_leader(cluster_spec, task_type, task_id):
  """Return the job name for the leader of collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string indicating the leader job name or empty string if no need to set
    leader job.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # No need to set collective leader for local.
  if not cluster_spec.as_dict():
    return ""

  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # Only one evaluator, so no need to set collective leader.
  if task_type == "evaluator":
    return ""

  # Use chief if chief is in the cluster.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  # Use worker 0 if no chief job.
  assert "worker" in cluster_spec.jobs
  return "/job:worker/replica:0/task:0"


def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster; must be "chief", "worker" or
      "evaluator".

  Returns:
    an int indicating the number of workers. For an "evaluator" task this is
    the number of "evaluator" tasks; otherwise it is the total number of
    "chief" and "worker" tasks, as the "chief" is also a worker.

  Raises:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator", or if
      `cluster_spec` fails validation.
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # Other jobs such as "ps" shouldn't call this function.
  if task_type not in ["chief", "worker", "evaluator"]:
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # The "evaluator" is in its own cluster or its own partition of a cluster.
    # So we don't have to count "chief" or "worker" if the current task is an
    # "evaluator".
    return len(cluster_spec["evaluator"])
  else:
    # In the non-evaluator case, we return the total number of "chief" and
    # "worker" tasks as the "chief" is also a worker.
    return (len(cluster_spec.get("chief", [])) + len(
        cluster_spec.get("worker", [])))


def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  It returns an id ranging from [0, `worker_count(cluster_spec, task_type)`).

  Note: this function assumes that the "evaluator" job is in its own cluster
  or its own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # The "chief" job has always id 0 and there is at most one and "worker" jobs
  # come after it.
  if task_type == "chief":
    return 0

  if task_type == "worker":
    return task_id + len(cluster_spec.get("chief", []))

  # The "evaluator" is in its own cluster or its own partition of a cluster.
  if task_type == "evaluator":
    return task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)


def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training, if saving checkpoint is requested by user, or needed
  for fault-tolerance, the cluster should save checkpoint but not necessarily
  every worker in the cluster should.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as there
  can be other files to save such as summary.

  Returns:
    Whether this particular worker in the cluster should save checkpoints.
  """
  return dc_context.get_current_worker_context().should_checkpoint


def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training, if loading checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should load checkpoint but not
  necessarily every worker in the cluster should.

  Returns:
    Whether this particular worker in the cluster should load checkpoints.
  """
  return dc_context.get_current_worker_context().experimental_should_init


def wait_for_other_workers():
  """Waits for other workers to reach the same call to this method."""
  return dc_context.get_current_worker_context().wait_for_other_workers()


def has_worker_context():
  """Returns whether a worker context has been entered."""
  return dc_context.get_current_worker_context() is not None