# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for multi-worker distribution strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.training import server_lib


def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object. A `ClusterSpec` input is returned unchanged; a
    dict or `ClusterDef` is wrapped in a new `ClusterSpec`.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    return server_lib.ClusterSpec(cluster_spec)
  elif not isinstance(cluster_spec, server_lib.ClusterSpec):
    raise ValueError(
        "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return cluster_spec


# TODO(yuefengz): add more validations.
def _validate_cluster_spec(cluster_spec, task_type, task_id):
  """Validates `cluster_spec`.

  It checks
  1) whether there is such a task type as `task_type` in the
     `cluster_spec`.
  2) whether there is at most one "chief" job.
  3) whether the `task_id` is non-negative and smaller than the number of
     `task_type` tasks.

  Checks 1) and 3) are skipped when `task_type` is falsy (e.g. None), since
  there is then no job in which to look the task up.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    the normalized `cluster_spec` as returned by `ClusterSpec.as_dict()`, so
    callers do not need to normalize it a second time.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
  if task_type and task_type not in cluster_spec:
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
  if len(cluster_spec.get("chief", [])) > 1:
    raise ValueError("There must be at most one 'chief' job.")
  # Without this guard a falsy `task_type` reaches `cluster_spec[task_type]`
  # and escapes as a KeyError rather than the documented ValueError.
  if task_type:
    if task_id < 0:
      raise ValueError("The `task_id` %d must be non-negative." % task_id)
    if task_id >= len(cluster_spec[task_type]):
      raise ValueError("The `task_id` %d exceeds the maximum id of %s." %
                       (task_id, task_type))
  return cluster_spec


def is_chief(cluster_spec, task_type, task_id):
  """Returns whether the given task is chief in the cluster.

  A task is chief if it is the "chief" task or, when the cluster has no
  "chief" job, if it is worker 0.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec`, `task_id` is
      negative or exceeds the maximum id of the `task_type`, or the cluster
      has more than one "chief".
  """
  cluster_spec = _validate_cluster_spec(cluster_spec, task_type, task_id)

  if task_type == "chief":
    return True

  # If chief not in the cluster_spec, use the first worker as chief. This is
  # common in CollectiveAllReduceStrategy.
  return ("chief" not in cluster_spec and task_type == "worker" and
          task_id == 0)


def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the type of the task asking for the count; must be "chief",
      "worker" or "evaluator".

  Returns:
    for an "evaluator", the number of "evaluator" tasks; otherwise the total
    number of "chief" and "worker" tasks.

  Raises:
    ValueError: if `cluster_spec` fails validation, or `task_type` is not
      "chief", "worker" or "evaluator".
  """
  cluster_spec = _validate_cluster_spec(cluster_spec, task_type, task_id=0)

  # Other jobs such as "ps" shouldn't call this function.
  if task_type not in ["chief", "worker", "evaluator"]:
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # The "evaluator" is in its own cluster or its own partition of a cluster.
    # So we don't have to count "chief" or "worker" if the current task is an
    # "evaluator".
    return len(cluster_spec["evaluator"])
  # In the non-evaluator case, we return the total number of "chief" and
  # "worker" tasks as the "chief" is also a worker.
  return (len(cluster_spec.get("chief", [])) +
          len(cluster_spec.get("worker", [])))


def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  It returns an id in the range
  [0, `worker_count(cluster_spec, task_type)`).

  Note: this function assumes that the "evaluator" job is in its own cluster
  or its own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `cluster_spec` fails validation, or `task_type` is not
      "chief", "worker" or "evaluator".
  """
  cluster_spec = _validate_cluster_spec(cluster_spec, task_type, task_id)

  # The single "chief" (there is at most one) always gets id 0, and the
  # "worker" ids are shifted up by the number of chiefs so they follow it.
  if task_type == "chief":
    return 0

  if task_type == "worker":
    return task_id + len(cluster_spec.get("chief", []))

  # The "evaluator" is in its own cluster or its own partition of a cluster.
  if task_type == "evaluator":
    return task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)