# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow collective Ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import gen_collective_ops


def all_reduce(t,
               group_size,
               group_key,
               instance_key,
               merge_op='Add',
               final_op='Id',
               subdiv_offsets=(0,),
               communication_hint='auto',
               timeout=0):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: the total number of tensors to be collectively reduced.
      Each must reside on a different device.  Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each
      partial reduction.
    final_op: string naming the unary Op to be applied to each fully
      reduced value.  Can be 'Id' for no operation.
    subdiv_offsets: a list of integer offsets into the tensor at which each
      independent subdivision should begin.  Use [0] if no subdivision should
      be done.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds.  If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed reduction.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_reduce must be at least 1.')
  return gen_collective_ops.collective_reduce(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      subdiv_offsets=subdiv_offsets,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
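

# A minimal usage sketch for `all_reduce`.  It assumes two logical CPU devices
# are available (e.g. configured via tf.config.set_logical_device_configuration)
# and uses arbitrary group/instance keys; it is illustrative only, and most
# callers reach these collectives through tf.distribute strategies instead.
def _example_all_reduce():
  """Illustrative sketch: sums a small tensor across two CPU devices."""
  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import ops

  @def_function.function
  def _run():
    results = []
    for i in range(2):
      with ops.device('/CPU:%d' % i):
        # Each device contributes its own tensor; every participant must use
        # the same group_size, group_key and instance_key (example values).
        results.append(
            all_reduce(
                constant_op.constant([i + 1.0]),
                group_size=2,
                group_key=1,
                instance_key=1))
    # With the default 'Add'/'Id' merge and final ops, both results are [3.0].
    return results

  return _run()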


def all_reduce_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  merge_op='Add',
                  final_op='Id',
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None,
                  max_subdivs_per_device=-1):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: an int32 tensor. The total number of tensors to be collectively
      reduced.  Each must reside on a different device.  Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each partial
      reduction.
    final_op: string naming the unary Op to be applied to each fully reduced
      value.  Can be 'Id' for no operation.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds.  If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised.  This feature is experimental.
    ordering_token: an optional resource tensor to pass to the op as an input.
      It isn't used by the kernel but allows AutoControlDependency to order
      the collectives with control dependencies.
    max_subdivs_per_device: int specifying the maximum number of subdivisions
      a tensor on a device can be divided into.  The runtime uses this
      constraint to parallelize processing of each per-device tensor.  Setting
      to -1 disables subdivision and reverts to the previous behavior of not
      subdividing the tensor.  Setting to 0 uses system defaults.

  Returns:
    An Op implementing the distributed reduction.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  return gen_collective_ops.collective_reduce_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token or [],
      max_subdivs_per_device=max_subdivs_per_device)
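

# A sketch of `all_reduce_v2`, which differs from `all_reduce` mainly in that
# group_size, group_key and instance_key are int32 tensors, so they can be
# computed at runtime.  Same assumptions as the example above: two logical CPU
# devices and arbitrary example keys.
def _example_all_reduce_v2():
  """Illustrative sketch: sums a tensor across two CPU devices (v2 op)."""
  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import ops

  @def_function.function
  def _run():
    # Tensor-valued group parameters (example values, shared by both devices).
    group_size = constant_op.constant(2)
    group_key = constant_op.constant(2)
    instance_key = constant_op.constant(2)
    results = []
    for i in range(2):
      with ops.device('/CPU:%d' % i):
        results.append(
            all_reduce_v2(
                constant_op.constant([i + 1.0]),
                group_size=group_size,
                group_key=group_key,
                instance_key=instance_key))
    # Both results are [3.0] with the default 'Add'/'Id' ops.
    return results

  return _run()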


def all_gather(t,
               group_size,
               group_key,
               instance_key,
               communication_hint='auto',
               timeout=0):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device. Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds. If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised. This feature is experimental.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_gather must be at least 1.')
  return gen_collective_ops.collective_gather(
      t,
      shape=[0],
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
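

# A sketch of `all_gather`, under the same assumptions as the examples above
# (two logical CPU devices, arbitrary keys).  The per-device tensors are
# concatenated along their first dimension.
def _example_all_gather():
  """Illustrative sketch: gathers one row from each of two CPU devices."""
  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import ops

  @def_function.function
  def _run():
    results = []
    for i in range(2):
      with ops.device('/CPU:%d' % i):
        # Each device contributes a [1, 1] tensor; the gathered result on
        # every device is [[1.], [2.]], i.e. shape [2, 1].
        results.append(
            all_gather(
                constant_op.constant([[i + 1.0]]),
                group_size=2,
                group_key=3,
                instance_key=3))
    return results

  return _run()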


def all_gather_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: an int32 tensor, the total number of tensors to be collectively
      accumulated. Each must reside on a different device. Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds. If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised. This feature is experimental.
    ordering_token: an optional resource tensor to pass to the op as an input.
      It isn't used by the kernel but allows AutoControlDependency to order
      the collectives with control dependencies.

  Returns:
    An Op implementing the distributed operation.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  return gen_collective_ops.collective_gather_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token or [])
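

# A sketch of `all_gather_v2`; it mirrors the `all_gather` example above but
# passes group_size and the keys as int32 tensors.  Same assumptions apply
# (two logical CPU devices, arbitrary example keys).
def _example_all_gather_v2():
  """Illustrative sketch: gathers one row per device using tensor keys."""
  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import ops

  @def_function.function
  def _run():
    results = []
    for i in range(2):
      with ops.device('/CPU:%d' % i):
        results.append(
            all_gather_v2(
                constant_op.constant([[i + 1.0]]),
                group_size=constant_op.constant(2),
                group_key=constant_op.constant(4),
                instance_key=constant_op.constant(4)))
    # Every result is [[1.], [2.]].
    return results

  return _run()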


def broadcast_send(t,
                   shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    shape: the shape of the tensor being sent, which must agree with t.
    dtype: the type of the tensor being sent, which must agree with t.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating.  Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds.  If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.

  Raises:
    ValueError: if any of the input parameter constraints are not met.

  Note that the shape and dtype arguments appear redundant since they
  should be obtainable from t.  There are two reasons for including
  them.  First, the shape and type of tensors passed via broadcast must
  be known ahead of time in their most specific form so that the receive
  side can allocate memory for the operation and shape/type inference can
  carry forward from there.  Including the same declarations on the
  send side clarifies a commitment already made.  Second, having nearly
  identical use syntax for send and receive sides may simplify tool-driven
  generation of broadcast.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_send must be at least 2.')
  if t.shape != shape:
    raise ValueError(
        'Shape of broadcast_send tensor not equal to declared shape')
  if t.dtype != dtype:
    raise ValueError(
        'Type of broadcast_send tensor not equal to declared type')
  return gen_collective_ops.collective_bcast_send(
      t,
      shape=shape,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_send_v2(t,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
        the total number of devices participating.  Each tensor must reside on a
        different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds.  If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.
  """
  return gen_collective_ops.collective_bcast_send_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv(shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating.  Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds.  If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_recv must be at least 2.')
  return gen_collective_ops.collective_bcast_recv(
      shape=shape,
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
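

# A paired sketch of `broadcast_send` and `broadcast_recv`: device 0 sends a
# tensor and device 1 receives it.  As above, it assumes two logical CPU
# devices; the keys are arbitrary example values, and group_size counts the
# sender as well as the receivers.
def _example_broadcast():
  """Illustrative sketch: broadcasts a [2] float tensor from CPU:0 to CPU:1."""
  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import ops

  @def_function.function
  def _run():
    with ops.device('/CPU:0'):
      value = constant_op.constant([1.0, 2.0])
      sent = broadcast_send(
          value, value.shape, value.dtype,
          group_size=2, group_key=5, instance_key=5)
    with ops.device('/CPU:1'):
      # The receiver declares the same shape and dtype up front so memory can
      # be allocated and shape inference can proceed.
      received = broadcast_recv(
          [2], dtypes.float32, group_size=2, group_key=5, instance_key=5)
    # Both outputs hold [1.0, 2.0].
    return sent, received

  return _run()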


def broadcast_recv_v2(shape,
                      dtype,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: an int tensor.  Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
        the total number of devices participating.  Each tensor must reside on a
        different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float, in seconds.  If non-zero, sets a completion timeout to
      detect staleness; if the timer expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.
  """
  return gen_collective_ops.collective_bcast_recv_v2(
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      shape=shape,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)

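
# A sketch of the v2 broadcast pair, where group_size, the keys and (on the
# receive side) the shape are passed as tensors.  Same assumptions as the
# earlier examples: two logical CPU devices and arbitrary example keys.
def _example_broadcast_v2():
  """Illustrative sketch: broadcasts a [2] float tensor using the v2 ops."""
  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import ops

  @def_function.function
  def _run():
    group_size = constant_op.constant(2)
    group_key = constant_op.constant(6)
    instance_key = constant_op.constant(6)
    with ops.device('/CPU:0'):
      value = constant_op.constant([1.0, 2.0])
      sent = broadcast_send_v2(value, group_size, group_key, instance_key)
    with ops.device('/CPU:1'):
      # The expected shape is supplied as an int tensor on the receive side.
      received = broadcast_recv_v2(
          constant_op.constant([2]), dtypes.float32, group_size, group_key,
          instance_key)
    # Both outputs hold [1.0, 2.0].
    return sent, received

  return _run()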