# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util for running models in a distribution setting.

Mostly from
https://github.com/tensorflow/models/blob/master/official/utils/misc/distribution_utils.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import tensorflow as tf


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"].
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`].  Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]


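# Illustrative usage sketch (not part of the original module): shows how the
# value returned by _collective_communication is consumed; the "nccl" choice
# here is an assumed example.
def _example_collective_communication():
  communication = _collective_communication("nccl")
  return tf.distribute.experimental.MultiWorkerMirroredStrategy(
      communication=communication)

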
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "[`nccl`, `hierarchical_copy`].  Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)


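# Illustrative usage sketch (not part of the original module): the
# CrossDeviceOps built here is what get_distribution_strategy passes to
# MirroredStrategy below; "hierarchical_copy", num_packs=2 and the two GPU
# devices are assumed example values.
def _example_mirrored_cross_device_ops():
  cross_device_ops = _mirrored_cross_device_ops("hierarchical_copy", 2)
  return tf.distribute.MirroredStrategy(
      devices=["device:GPU:0", "device:GPU:1"],
      cross_device_ops=cross_device_ops)

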
def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored", and
      "multi_worker_mirrored" -- case insensitive. "off" means not to use
      Distribution Strategy.
    num_gpus: number of GPUs to run this model.
    all_reduce_alg: optional string specifying which all-reduce algorithm to
      use. For `mirrored`, valid values are "nccl" and "hierarchical_copy";
      for `multi_worker_mirrored`, valid values are "ring" and "nccl". If
      None, DistributionStrategy will choose based on device topology.
    num_packs: an integer specifying number of packs for the cross device op
      used by `mirrored`.

  Returns:
    tf.distribute.DistributionStrategy object, or None if
    `distribution_strategy` is "off".

  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; or `num_gpus` is negative.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  distribution_strategy = distribution_strategy.lower()

  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError("When {} GPUs are specified, distribution_strategy "
                       "flag cannot be set to `off`.".format(num_gpus))
    return None

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  raise ValueError("Unrecognized Distribution Strategy: %r" %
                   distribution_strategy)


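# Illustrative usage sketch (not part of the original module): builds a
# strategy and creates a Keras model under its scope. The "mirrored",
# num_gpus=2 and all_reduce_alg="nccl" values are assumed examples; adjust
# them to the available hardware.
def _example_get_distribution_strategy():
  strategy = get_distribution_strategy(
      distribution_strategy="mirrored", num_gpus=2, all_reduce_alg="nccl")
  with strategy.scope():
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(10, activation="softmax")])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
  return model

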
def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of the current worker within `worker_hosts`. Required
      when more than one worker is specified.

  Returns:
    Number of workers in the cluster.

  Raises:
    ValueError: if more than one worker is specified but `task_index` is
      negative.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if tf_config:
    num_workers = (
        len(tf_config["cluster"].get("chief", [])) +
        len(tf_config["cluster"].get("worker", [])))
  elif worker_hosts:
    workers = worker_hosts.split(",")
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError("Must specify task_index when number of workers > 1")
    task_index = 0 if num_workers == 1 else task_index
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": workers
        },
        "task": {
            "type": "worker",
            "index": task_index
        }
    })
  else:
    num_workers = 1
  return num_workers


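# Illustrative usage sketch (not part of the original module): writes a
# two-worker TF_CONFIG and requests a matching strategy. The host:port pairs
# and task_index=0 are assumed example values.
def _example_configure_cluster():
  num_workers = configure_cluster(
      worker_hosts="10.0.0.1:5000,10.0.0.2:5000", task_index=0)
  strategy = get_distribution_strategy("multi_worker_mirrored")
  return num_workers, strategy

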
def get_strategy_scope(strategy):
  """Returns the strategy's scope, or a dummy scope if strategy is None."""
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()

  return strategy_scope


class DummyContextManager(object):
  """No-op context manager used when no distribution strategy is in effect."""

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass
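

# Illustrative usage sketch (not part of the original module): model-building
# code can use get_strategy_scope so it reads the same whether or not a
# distribution strategy is in use; strategy may be None (e.g. when
# distribution_strategy is "off").
def _example_get_strategy_scope(strategy=None):
  with get_strategy_scope(strategy):
    return tf.keras.Sequential([tf.keras.layers.Dense(1)])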
191