# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for multi-worker distribution strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.training import server_lib


def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    return server_lib.ClusterSpec(cluster_spec)
  elif not isinstance(cluster_spec, server_lib.ClusterSpec):
    raise ValueError(
        "`cluster_spec` should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return cluster_spec


def task_count(cluster_spec, task_type):
  """Returns the number of tasks of `task_type`, or 0 if the job is absent."""
  try:
    return cluster_spec.num_tasks(task_type)
  except ValueError:
    # `num_tasks` raises ValueError for a job name not in the spec; treat a
    # missing job as having zero tasks.
    return 0


def _validate_cluster_spec(cluster_spec, task_type, task_id):
  """Validates `cluster_spec`.

  It checks:
  1) task type is one of "chief", "worker", "ps", "evaluator", or not provided
     (None).
  2) whether there is such a task type as `task_type` in the `cluster_spec`. The
     only exception is `evaluator`. In other words, it is still a valid
     configuration when `task_type` is `evaluator` but it doesn't appear in
     `cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
  3) whether there is at most one "chief" job.
  4) whether there is at most one "evaluator" job.
  5) whether the `task_id` is smaller than the number of tasks for that
     particular `task_type`.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  allowed_task_types = ("chief", "worker", "evaluator", "ps", None)

  cluster_spec = normalize_cluster_spec(cluster_spec)

  if any(job not in allowed_task_types for job in cluster_spec.jobs):
    raise ValueError("Disallowed task type found in cluster spec. Allowed "
                     "types are {} and the cluster spec is {}.".format(
                         allowed_task_types, cluster_spec))

  if task_type not in allowed_task_types:
    raise ValueError(
        "Unrecognized task_type: {}, valid task types are: {}".format(
            task_type, allowed_task_types))

  # The "evaluator" job is allowed to be missing from the spec (TF_CONFIG
  # compatibility); every other non-None task type must be present.
  if (task_type and task_type not in cluster_spec.jobs and
      task_type != "evaluator"):
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  if task_count(cluster_spec, "chief") > 1:
    raise ValueError("There must be at most one 'chief' job.")

  if task_count(cluster_spec, "evaluator") > 1:
    raise ValueError("There must be at most one 'evaluator' job.")

  # The `evaluator` job is allowed to be missing in `cluster_spec`.
  if task_type in cluster_spec.jobs and task_id >= task_count(
      cluster_spec, task_type):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))


def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is chief in the cluster.

  Since there is at most one evaluator and the evaluator itself should be
  independent of the training cluster, the evaluator job is also a chief job on
  its own.

  If this is currently running under a `_WorkerContext` of distribute
  coordinator, the arguments can be omitted as the result is already available.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
  if has_worker_context():
    # If a worker context exists, use the value provided by it.
    return dc_context.get_current_worker_context().is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type == "chief" or task_type == "evaluator":
    return True

  # If chief not in the cluster_spec, use the first worker as chief. This is
  # common in CollectiveAllReduceStrategy.
  if ("chief" not in cluster_spec and task_type == "worker" and task_id == 0):
    return True
  return False


def collective_leader(cluster_spec, task_type, task_id):
  """Return the job name for the leader of for collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string indicating the leader job name or empty string if no need to set
    leader job.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # No need to set collective leader for local.
  if not cluster_spec.as_dict():
    return ""

  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # Only one evaluator, so no need to set collective leader.
  if task_type == "evaluator":
    return ""

  # Use chief if chief is in the cluster.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  # Use worker 0 if no chief job.
  assert "worker" in cluster_spec.jobs
  return "/job:worker/replica:0/task:0"


def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster."""
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # Other jobs such as "ps" shouldn't call this function.
  if task_type not in ["chief", "worker", "evaluator"]:
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # The "evaluator" is in its own cluster or its own partition of a cluster.
    # So we don't have to count "chief" or "worker" if the current task is an
    # "evaluator".
    # NOTE(review): validation allows `task_type="evaluator"` even when the
    # "evaluator" job is absent from the spec, in which case this raises an
    # undocumented KeyError — confirm whether callers rely on that.
    return len(cluster_spec["evaluator"])
  else:
    # In the non-evaluator case, we return the total number of "chief" and
    # "worker" tasks as the "chief" is also a worker.
    return (len(cluster_spec.get("chief", [])) + len(
        cluster_spec.get("worker", [])))


def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  It returns an id ranging from [0, `worker_count(task_type, task_id)`).

  Note: this function assumes that "evaluate" job is in its own cluster or its
  own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # The "chief" job has always id 0 and there is at most one and "worker" jobs
  # come after it.
  if task_type == "chief":
    return 0

  if task_type == "worker":
    return task_id + len(cluster_spec.get("chief", []))

  # The "evaluator" is in its own cluster or its own partition of a cluster.
  if task_type == "evaluator":
    return task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)


def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training, if saving checkpoint is requested by user, or needed
  for fault-tolerance, the cluster should save checkpoint but not necessarily
  every worker in the cluster should.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as there
  can be other files to save such as summary.

  Returns:
    Whether this particular worker in the cluster should save checkpoints.
  """
  return dc_context.get_current_worker_context().should_checkpoint


def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training, if loading checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should load checkpoint but not
  necessarily every worker in the cluster should.

  Returns:
    Whether this particular worker in the cluster should load checkpoints.
  """
  return dc_context.get_current_worker_context().experimental_should_init


def wait_for_other_workers():
  """Waits for other workers to reach the same call to this method."""
  return dc_context.get_current_worker_context().wait_for_other_workers()


def has_worker_context():
  """Returns whether a worker context has been entered."""
  return dc_context.get_current_worker_context() is not None