# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for running models in a distribution setting.

Mostly from
https://github.com/tensorflow/models/blob/master/official/
utils/misc/distribution_utils.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import tensorflow as tf


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"].
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying the number of packs for the cross device
      op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)


def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored", and
      "multi_worker_mirrored" -- case insensitive. "off" means not to use
      Distribution Strategy.
    num_gpus: number of GPUs to run this model on.
    all_reduce_alg: optional string specifying which algorithm to use when
      performing all-reduce. For `multi_worker_mirrored`, valid values are
      "ring" and "nccl"; for `mirrored`, valid values are "nccl" and
      "hierarchical_copy". If None, the strategy chooses based on device
      topology.
    num_packs: an integer specifying the number of packs passed to the cross
      device op used by `mirrored`. Ignored by the other strategies.

  Returns:
    tf.distribute.DistributionStrategy object.

  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; or `num_gpus` is negative.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` cannot be negative.")

  distribution_strategy = distribution_strategy.lower()

  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError("distribution_strategy cannot be set to `off` when "
                       "{} GPUs are specified.".format(num_gpus))
    return None

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` cannot be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  raise ValueError("Unrecognized Distribution Strategy: %r" %
                   distribution_strategy)


def configure_cluster(worker_hosts=None, task_index=-1):
  """Set the multi-worker cluster spec in the TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of the worker running this process; required when
      `worker_hosts` specifies more than one worker.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if tf_config:
    num_workers = (
        len(tf_config["cluster"].get("chief", [])) +
        len(tf_config["cluster"].get("worker", [])))
  elif worker_hosts:
    workers = worker_hosts.split(",")
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError("Must specify task_index when number of workers > 1")
    task_index = 0 if num_workers == 1 else task_index
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": workers
        },
        "task": {
            "type": "worker",
            "index": task_index
        }
    })
  else:
    num_workers = 1
  return num_workers


def get_strategy_scope(strategy):
  """Return `strategy.scope()` if a strategy is given, else a no-op context."""
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()

  return strategy_scope


class DummyContextManager(object):
  """No-op context manager used when no distribution strategy is in effect."""

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass
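

# Example usage (a minimal sketch, not part of the upstream utility): build a
# strategy for the locally visible GPUs and compile a toy Keras model under
# its scope. The tiny model below is illustrative only.
if __name__ == "__main__":
  local_gpus = len(tf.config.list_physical_devices("GPU"))
  strategy = get_distribution_strategy(
      distribution_strategy="mirrored", num_gpus=local_gpus)
  with get_strategy_scope(strategy):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer="adam", loss="mse")
  # `strategy` may be None when "off" is used; fall back to a single replica.
  replicas = strategy.num_replicas_in_sync if strategy else 1
  print("Model compiled with %d replica(s) in sync." % replicas)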