# coding=utf-8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import enum

from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
@tf_export("distribute.experimental.CommunicationImplementation",
           "distribute.experimental.CollectiveCommunication")
class CommunicationImplementation(enum.Enum):
  """Cross-device communication implementation.

  Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
  deprecated and will be removed in a future version. Use
  `tf.distribute.experimental.CommunicationImplementation` instead.

  * `AUTO`: Automatically chosen by TensorFlow.
  * `RING`: TensorFlow's ring algorithms for all-reduce and all-gather.
  * `NCCL`: NVIDIA®'s NCCL library. This is now only used for all-reduce on
    GPUs; all-reduce on CPU, all-gather, and broadcast fall back to RING.
  """
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"
  # TODO(ayushd): add ncclAllGather implementation.


CollectiveCommunication = CommunicationImplementation


@tf_export("distribute.experimental.CommunicationOptions")
class _OptionsExported(object):
  """Options for cross-device communications like all-reduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies
  and are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      bytes_per_pack=50 * 1024 * 1024,
      timeout_seconds=120,
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, options=options)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  """

  def __new__(cls, *args, **kwargs):
    # We expose a dummy class so that we can separate internal and public APIs.
    # Note that __init__ won't be called on the returned object if it's a
    # different class [1].
    # [1] https://docs.python.org/3/reference/datamodel.html#object.__new__
    return Options(*args, **kwargs)

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    """Creates a CommunicationOptions.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This only applies to all-reduce with
        `MultiWorkerMirroredStrategy` currently.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. Zero disables the timeout. This can be useful when
        debugging hanging issues. This should only be used for debugging since
        it creates a new thread for each collective, i.e. an overhead of
        `timeout_seconds * num_collectives_per_second` more threads. This only
        works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
      implementation: a
        `tf.distribute.experimental.CommunicationImplementation`. This is a
        hint on the preferred communication implementation. Possible values
        include `AUTO`, `RING`, and `NCCL`. NCCL is generally more performant
        on GPUs but does not work on CPUs. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass


class Options(object):
  """Implementation of `tf.distribute.experimental.CommunicationOptions`."""

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    if bytes_per_pack < 0:
      raise ValueError("bytes_per_pack must be non-negative")
    if isinstance(implementation, str):
      implementation = CommunicationImplementation(implementation.upper())
    if not isinstance(implementation, CommunicationImplementation):
      raise ValueError("implementation should be a "
                       "tf.distribute.experimental.CommunicationImplementation")
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds
    self.implementation = implementation

  __init__.__doc__ = _OptionsExported.__init__.__doc__

  def merge(self, options):
    """Merges with another `Options` and returns a new one.

    Values specified in `options` take precedence if they are not the default.

    Args:
      options: a `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A new `tf.distribute.experimental.CommunicationOptions`.
    """
    merged = copy.deepcopy(self)
    if options is None:
      return merged
    if options.bytes_per_pack != 0:
      merged.bytes_per_pack = options.bytes_per_pack
    if options.timeout_seconds is not None:
      merged.timeout_seconds = options.timeout_seconds
    if options.implementation != CommunicationImplementation.AUTO:
      merged.implementation = options.implementation
    return merged
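

# The following helper is an illustrative sketch added for documentation
# purposes and is not part of the TensorFlow sources; the name
# `_merge_example` is hypothetical. It demonstrates the `merge` semantics
# documented above: non-default fields of the argument take precedence, while
# default fields fall back to the values already set on `self`.
def _merge_example():
  base = Options(bytes_per_pack=32 * 1024 * 1024)
  override = Options(timeout_seconds=60, implementation="NCCL")
  merged = base.merge(override)
  assert merged.bytes_per_pack == 32 * 1024 * 1024  # Default in `override`, kept from `base`.
  assert merged.timeout_seconds == 60  # Non-default, taken from `override`.
  assert merged.implementation is CommunicationImplementation.NCCL
  return merged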


@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
  """Hints for collective operations like AllReduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies
  and are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  - bytes_per_pack

  ```python
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=50 * 1024 * 1024)
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  - timeout_seconds

  ```python
  strategy = tf.distribute.MirroredStrategy()
  hints = tf.distribute.experimental.CollectiveHints(
      timeout_seconds=120)
  try:
    strategy.reduce("sum", v, axis=None, experimental_hints=hints)
  except tf.errors.DeadlineExceededError:
    do_something()
  ```

  """

  @deprecation.deprecated(
      None, "use distribute.experimental.CommunicationOptions instead")
  def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
    return Options(
        bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)

  def __init__(self, bytes_per_pack=0, timeout_seconds=None):
    """Creates a CollectiveHints.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This only applies to all-reduce with
        `MultiWorkerMirroredStrategy` currently.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. This can be useful when debugging hanging issues.
        This should only be used for debugging since it creates a new thread
        for each collective, i.e. an overhead of `timeout_seconds *
        num_collectives_per_second` more threads. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass
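

# Illustrative sketch, not part of the TensorFlow sources; the name
# `_public_wrappers_example` is hypothetical. Both exported wrappers above
# override `__new__` to return the internal `Options` class, so constructing
# either `CommunicationOptions` or the deprecated `CollectiveHints` yields an
# `Options` instance, and the wrapper's `__init__` is never run.
def _public_wrappers_example():
  opts = _OptionsExported(bytes_per_pack=4 * 1024 * 1024)
  hints = Hints(timeout_seconds=30)  # Also logs a deprecation warning.
  assert isinstance(opts, Options)
  assert isinstance(hints, Options)
  assert opts.implementation is CommunicationImplementation.AUTO
  return opts, hints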