# coding=utf-8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import enum

from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
@tf_export("distribute.experimental.CommunicationImplementation",
           "distribute.experimental.CollectiveCommunication")
class CommunicationImplementation(enum.Enum):
  """Cross device communication implementation.

  Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
  deprecated and will be removed in a future version. Use
  `tf.distribute.experimental.CommunicationImplementation` instead.

  * `AUTO`: Automatically chosen by TensorFlow.
  * `RING`: TensorFlow's ring algorithms for all-reduce and all-gather.
  * `NCCL`: NVIDIA®'s NCCL library. This is currently only used for all-reduce
    on GPUs; all-reduce on CPU, all-gather, and broadcast fall back to RING.
  """
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"
  # TODO(ayushd): add ncclAllGather implementation.


CollectiveCommunication = CommunicationImplementation
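
# Usage sketch (illustrative only): constructing communication options with an
# implementation hint. Both the enum and its string name are accepted, since
# `Options.__init__` below upper-cases strings and converts them to
# `CommunicationImplementation`.
#
#   options = tf.distribute.experimental.CommunicationOptions(
#       implementation=(
#           tf.distribute.experimental.CommunicationImplementation.NCCL))
#   # Equivalent, using the string form:
#   options = tf.distribute.experimental.CommunicationOptions(
#       implementation="nccl")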


@tf_export("distribute.experimental.CommunicationOptions")
class _OptionsExported(object):
  """Options for cross device communications like All-reduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies and
  are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      bytes_per_pack=50 * 1024 * 1024,
      timeout_seconds=120,
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, options=options)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  """

  def __new__(cls, *args, **kwargs):
    # We expose a dummy class so that we can separate internal and public APIs.
    # Note that __init__ won't be called on the returned object if it's a
    # different class [1].
    # [1] https://docs.python.org/3/reference/datamodel.html#object.__new__
    return Options(*args, **kwargs)

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    """Creates a CommunicationOptions.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of certain size. If it's zero, the value is determined
        automatically. This currently only applies to all-reduce with
        `MultiWorkerMirroredStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. Zero disables the timeout. This can be useful when
        debugging hanging issues, but should only be used for debugging since
        it creates a new thread for each collective, i.e. an overhead of
        `timeout_seconds * num_collectives_per_second` more threads. This only
        works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
      implementation: a
        `tf.distribute.experimental.CommunicationImplementation`. This is a
        hint on the preferred communication implementation. Possible values
        include `AUTO`, `RING`, and `NCCL`. NCCL is generally more performant
        on GPUs but doesn't work on CPUs. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass
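
# Usage sketch (illustrative only): catching the timeout described above,
# using the non-deprecated `options=` argument. Assumes a replica context
# created by `tf.distribute.experimental.MultiWorkerMirroredStrategy`;
# `grads` and `do_something` are placeholders.
#
#   options = tf.distribute.experimental.CommunicationOptions(
#       timeout_seconds=120)
#   try:
#     grads = tf.distribute.get_replica_context().all_reduce(
#         'sum', grads, options=options)
#   except tf.errors.DeadlineExceededError:
#     do_something()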


class Options(object):
  """Implementation of OptionsInterface."""

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    if bytes_per_pack < 0:
      raise ValueError("bytes_per_pack must be non-negative")
    if isinstance(implementation, str):
      implementation = CommunicationImplementation(implementation.upper())
    if not isinstance(implementation, CommunicationImplementation):
      raise ValueError("implementation should be a "
                       "tf.distribute.experimental.CommunicationImplementation")
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds
    self.implementation = implementation

  __init__.__doc__ = _OptionsExported.__init__.__doc__

  def merge(self, options):
    """Merges with another options object and returns a new one.

    Values specified in `options` take precedence if they're not the default.

    Args:
      options: a `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A new `tf.distribute.experimental.CommunicationOptions`.
    """
    merged = copy.deepcopy(self)
    if options is None:
      return merged
    if options.bytes_per_pack != 0:
      merged.bytes_per_pack = options.bytes_per_pack
    if options.timeout_seconds is not None:
      merged.timeout_seconds = options.timeout_seconds
    if options.implementation != CommunicationImplementation.AUTO:
      merged.implementation = options.implementation
    return merged
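
# Merge semantics sketch (illustrative only): non-default values in the
# argument to `merge` take precedence; defaults are kept from `self`.
#
#   base = Options(bytes_per_pack=32 * 1024 * 1024, timeout_seconds=60)
#   override = Options(timeout_seconds=120)  # bytes_per_pack left at default 0
#   merged = base.merge(override)
#   # merged.bytes_per_pack == 32 * 1024 * 1024   (kept from `base`)
#   # merged.timeout_seconds == 120               (taken from `override`)
#   # merged.implementation == CommunicationImplementation.AUTO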


@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
  """Hints for collective operations like AllReduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies and
  are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  - bytes_per_pack

  ```python
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=50 * 1024 * 1024)
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  - timeout_seconds

  ```python
  strategy = tf.distribute.MirroredStrategy()
  hints = tf.distribute.experimental.CollectiveHints(
      timeout_seconds=120)
  try:
    strategy.reduce("sum", v, axis=None, experimental_hints=hints)
  except tf.errors.DeadlineExceededError:
    do_something()
  ```

  """

  @deprecation.deprecated(
      None, "use distribute.experimental.CommunicationOptions instead")
  def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
    return Options(
        bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)

  def __init__(self, bytes_per_pack=0, timeout_seconds=None):
    """Creates a CollectiveHints.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of certain size. If it's zero, the value is determined
        automatically. This currently only applies to all-reduce with
        `MultiWorkerMirroredStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. This can be useful when debugging hanging issues,
        but should only be used for debugging since it creates a new thread
        for each collective, i.e. an overhead of `timeout_seconds *
        num_collectives_per_second` more threads. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass