# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow collective Ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import gen_collective_ops


def all_reduce(t,
               group_size,
               group_key,
               instance_key,
               merge_op='Add',
               final_op='Id',
               subdiv_offsets=(0,),
               communication_hint='auto',
               timeout=0):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: the total number of tensors to be collectively reduced.
      Each must reside on a different device.  Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each
      partial reduction.
    final_op: string naming the unary Op to be applied to each fully
      reduced value.  Can be 'Id' for no operation.
    subdiv_offsets: a list of integer offsets into the tensor at which each
      independent subdivision should begin.  Use [0] if no subdivision should
      be done.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float.  The timeout value in seconds.  If set to a non-zero
      value, a completion timeout is set to detect staleness; if the timer
      expires, a DeadlineExceededError is raised.  This feature is
      experimental.

  Returns:
    An Op implementing the distributed reduction.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_reduce must be at least 1.')
  return gen_collective_ops.collective_reduce(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      subdiv_offsets=subdiv_offsets,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


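# Example for all_reduce (a hedged usage sketch, not from the original
# source): a minimal two-device all-reduce built in graph mode.  The session
# config, device names and key values below are illustrative assumptions;
# all `group_size` members must be scheduled to run concurrently.
#
#   import tensorflow as tf
#   from tensorflow.python.ops import collective_ops
#
#   config = tf.compat.v1.ConfigProto(device_count={'CPU': 2})
#   with tf.Graph().as_default(), tf.compat.v1.Session(config=config) as sess:
#     with tf.device('/device:CPU:0'):
#       r0 = collective_ops.all_reduce(
#           tf.constant([1.0, 3.0]), group_size=2, group_key=1,
#           instance_key=1)
#     with tf.device('/device:CPU:1'):
#       r1 = collective_ops.all_reduce(
#           tf.constant([2.0, 4.0]), group_size=2, group_key=1,
#           instance_key=1)
#     # With the default merge_op='Add' and final_op='Id', both results are
#     # the elementwise sum [3.0, 7.0].
#     print(sess.run([r0, r1]))

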
def all_reduce_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  merge_op='Add',
                  final_op='Id',
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: an int32 tensor. The total number of tensors to be collectively
      reduced.  Each must reside on a different device.  Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each partial
      reduction.
    final_op: string naming the unary Op to be applied to each fully reduced
      value.  Can be 'Id' for no operation.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float.  The timeout value in seconds.  If set to a non-zero
      value, a completion timeout is set to detect staleness; if the timer
      expires, a DeadlineExceededError is raised.  This feature is
      experimental.
    ordering_token: an optional resource tensor to pass to the op as an input.
      It isn't used by the kernel but allows AutoControlDependency to order
      the collectives with control dependencies.

  Returns:
    An Op implementing the distributed reduction.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  return gen_collective_ops.collective_reduce_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token or [])


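# Example for all_reduce_v2 (a hedged usage sketch, not from the original
# source): the v2 op run eagerly on two logical CPU devices, one thread per
# group member, since every member must execute concurrently.  The
# logical-device setup, thread handling and key values are illustrative
# assumptions.
#
#   import threading
#   import tensorflow as tf
#   from tensorflow.python.ops import collective_ops
#
#   cpus = tf.config.list_physical_devices('CPU')
#   tf.config.set_logical_device_configuration(
#       cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2)
#
#   results = [None, None]
#   def reduce_on(idx):
#     with tf.device('/device:CPU:%d' % idx):
#       results[idx] = collective_ops.all_reduce_v2(
#           tf.constant([float(idx)]), group_size=2, group_key=2,
#           instance_key=2)
#   threads = [threading.Thread(target=reduce_on, args=(i,)) for i in (0, 1)]
#   for th in threads:
#     th.start()
#   for th in threads:
#     th.join()
#   # Each entry of `results` is the sum [1.0].

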
def all_gather(t,
               group_size,
               group_key,
               instance_key,
               communication_hint='auto',
               timeout=0):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device. Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. The timeout value in seconds. If set to a non-zero value,
      a completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_gather must be at least 1.')
  return gen_collective_ops.collective_gather(
      t,
      shape=[0],
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


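# Example for all_gather (a hedged usage sketch, not from the original
# source): gathering two single-row tensors in graph mode; each participant
# receives the row-wise concatenation of all inputs.  The session config,
# device names and key values are illustrative assumptions.
#
#   import tensorflow as tf
#   from tensorflow.python.ops import collective_ops
#
#   config = tf.compat.v1.ConfigProto(device_count={'CPU': 2})
#   with tf.Graph().as_default(), tf.compat.v1.Session(config=config) as sess:
#     with tf.device('/device:CPU:0'):
#       g0 = collective_ops.all_gather(
#           tf.constant([[0.0, 1.0]]), group_size=2, group_key=3,
#           instance_key=3)
#     with tf.device('/device:CPU:1'):
#       g1 = collective_ops.all_gather(
#           tf.constant([[2.0, 3.0]]), group_size=2, group_key=3,
#           instance_key=3)
#     # Both results have shape [2, 2], holding both contributed rows.
#     print(sess.run([g0, g1]))

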
def all_gather_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: an int32 tensor, the total number of tensors to be collectively
      accumulated. Each must reside on a different device. Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. The timeout value in seconds. If set to a non-zero value,
      a completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised. This feature is experimental.
    ordering_token: an optional resource tensor to pass to the op as an input.
      It isn't used by the kernel but allows AutoControlDependency to order
      the collectives with control dependencies.

  Returns:
    An Op implementing the distributed operation.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  return gen_collective_ops.collective_gather_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token or [])


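# Example for all_gather_v2 (a hedged usage sketch, not from the original
# source): the same threaded eager pattern as the all_reduce_v2 sketch above,
# with each of the two group members contributing one row.  Key values are
# illustrative assumptions; `tf` and `collective_ops` are imported as in that
# sketch.
#
#   def gather_on(idx):
#     with tf.device('/device:CPU:%d' % idx):
#       return collective_ops.all_gather_v2(
#           tf.constant([[float(idx)]]), group_size=2, group_key=4,
#           instance_key=4)
#   # When both calls run concurrently (one thread per logical CPU device),
#   # each returns a [2, 1] tensor containing both rows.

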
def broadcast_send(t,
                   shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    shape: the shape of the tensor being sent, which must agree with t.
    dtype: the type of the tensor being sent, which must agree with t.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating.  Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: the timeout value in seconds.  If set to a non-zero value, a
      completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.

  Raises:
    ValueError: if any of the input parameter constraints are not met.

  Note that the shape and dtype arguments appear redundant since they
  should be obtainable from t.  There are two reasons for including
  them.  First, the shape and type of tensors passed via broadcast must
  be known ahead of time in their most specific form so that the receive
  side can allocate memory for the operation and shape/type inference can
  carry forward from there.  Including the same declarations on the
  send side clarifies a commitment already made.  Second, having nearly
  identical usage syntax for the send and receive sides may simplify
  tool-driven generation of broadcast.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_send must be at least 2.')
  if t.shape != shape:
    raise ValueError(
        'Shape of broadcast_send tensor not equal to declared shape')
  if t.dtype != dtype:
    raise ValueError(
        'Type of broadcast_send tensor not equal to declared type')
  return gen_collective_ops.collective_bcast_send(
      t,
      shape=shape,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_send_v2(t,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
        the total number of devices participating.  Each tensor must reside on a
        different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: the timeout value in seconds.  If set to a non-zero value, a
      completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.
  """
  return gen_collective_ops.collective_bcast_send_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv(shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating.  Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: the timeout value in seconds.  If set to a non-zero value, a
      completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised.  This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_recv must be at least 2.')
  return gen_collective_ops.collective_bcast_recv(
      shape=shape,
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


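# Example for broadcast_send / broadcast_recv (a hedged usage sketch, not
# from the original source): pairing a send on one device with a receive on
# another in graph mode.  The session config, device names and key values are
# illustrative assumptions; shape, dtype, group_size, group_key and
# instance_key must match on both sides.
#
#   import tensorflow as tf
#   from tensorflow.python.ops import collective_ops
#
#   config = tf.compat.v1.ConfigProto(device_count={'CPU': 2})
#   with tf.Graph().as_default(), tf.compat.v1.Session(config=config) as sess:
#     with tf.device('/device:CPU:0'):
#       sent = collective_ops.broadcast_send(
#           tf.constant([5.0, 6.0]), shape=[2], dtype=tf.float32,
#           group_size=2, group_key=5, instance_key=5)
#     with tf.device('/device:CPU:1'):
#       received = collective_ops.broadcast_recv(
#           shape=[2], dtype=tf.float32, group_size=2, group_key=5,
#           instance_key=5)
#     # Both tensors evaluate to [5.0, 6.0].
#     print(sess.run([sent, received]))

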
def broadcast_recv_v2(shape,
                      dtype,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: an int tensor.  Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
        the total number of devices participating.  Each tensor must reside on a
        different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: the timeout value in seconds.  If set to a non-zero value, a
      completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised.  This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.
  """
  return gen_collective_ops.collective_bcast_recv_v2(
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      shape=shape,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


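# Example for broadcast_send_v2 / broadcast_recv_v2 (a hedged usage sketch,
# not from the original source): the v2 broadcast pair, using the same
# two-device graph-mode setup as the v1 sketch above.  Key values are
# illustrative assumptions; group_size, group_key, instance_key and shape may
# also be passed as int32 tensors.
#
#   with tf.device('/device:CPU:0'):
#     sent = collective_ops.broadcast_send_v2(
#         tf.constant([7.0, 8.0]), group_size=2, group_key=6, instance_key=6)
#   with tf.device('/device:CPU:1'):
#     received = collective_ops.broadcast_recv_v2(
#         shape=[2], dtype=tf.float32, group_size=2, group_key=6,
#         instance_key=6)
#   # When both ops run concurrently, each evaluates to [7.0, 8.0].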