• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Inter-process communication using MPI."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow as tf
22
23from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops
24from tensorflow.contrib.util import loader
25from tensorflow.python.framework import ops
26from tensorflow.python.platform import resource_loader
27
# Load the custom-op shared library `_mpi_ops.so` at import time; its path is
# resolved with `resource_loader.get_path_to_datafile`. The returned handle is
# not referenced elsewhere in this section. Presumably loading the library is
# what makes the kernels behind `gen_mpi_ops` available — TODO confirm.
_mpi_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile('_mpi_ops.so'))
30
31
def size(name=None):
  """Return an op that yields the number of MPI processes.

  Running the op is equivalent to calling
  `MPI_Comm_size(MPI_COMM_WORLD, ...)`, i.e. it reports the size of the global
  communicator.

  Args:
    name: Optional name for the created op.

  Returns:
    An integer scalar containing the number of MPI processes.
  """
  size_op = gen_mpi_ops.mpi_size(name=name)
  return size_op


# The op takes no tensor inputs, so register it as having no gradient.
ops.NotDifferentiable('MPISize')
45
46
def rank(name=None):
  """Return an op that yields the MPI rank of the calling process.

  Running the op is equivalent to calling
  `MPI_Comm_rank(MPI_COMM_WORLD, ...)`, i.e. it reports this process's rank in
  the global communicator.

  Args:
    name: Optional name for the created op.

  Returns:
    An integer scalar with the MPI rank of the calling process.
  """
  rank_op = gen_mpi_ops.mpi_rank(name=name)
  return rank_op


# The op takes no tensor inputs, so register it as having no gradient.
ops.NotDifferentiable('MPIRank')
60
61
def init(name=None):
  """Return an op that initializes MPI on the device where it is run.

  Every MPI op executed afterwards must run on the same device that this
  `init` op ran on.

  Args:
    name: Optional name for the created op.

  Returns:
    The op produced by `gen_mpi_ops.mpi_init`.
  """
  init_op = gen_mpi_ops.mpi_init(name=name)
  return init_op


# Initialization takes no tensor inputs, so register it as having no gradient.
ops.NotDifferentiable('MPIInit')
72
73
def local_rank(name=None):
  """Return an op that yields the node-local MPI rank of the calling process.

  The local rank is the rank among the processes running on the same node.
  For example, if seven processes run on one node, their local ranks are zero
  through six, inclusive.

  Running the op is equivalent to calling `MPI_Comm_rank(...)` on a new
  communicator that contains only the processes on the same node.

  Args:
    name: Optional name for the created op.

  Returns:
    An integer scalar with the local MPI rank of the calling process.
  """
  local_rank_op = gen_mpi_ops.mpi_local_rank(name=name)
  return local_rank_op


# The op takes no tensor inputs, so register it as having no gradient.
ops.NotDifferentiable('MPILocalRank')
89
90
def _allreduce(tensor, name=None):
  """Return an op that sums `tensor` over all the MPI processes.

  The reduction is keyed by the op's name. For a given name, the tensor's
  type and shape must be identical on every MPI process. The reduction does
  not begin until all processes are ready to send and receive the tensor.

  Args:
    tensor: The tensor to sum across processes.
    name: Optional name for the op; it is the key that matches up the
      reductions issued by the different processes.

  Returns:
    A tensor of the same shape and type as `tensor`, summed across all
    processes.
  """
  summed = gen_mpi_ops.mpi_allreduce(tensor, name=name)
  return summed


ops.NotDifferentiable('MPIAllreduce')
106
107
def allgather(tensor, name=None):
  """Concatenate `tensor` with the corresponding tensor on all MPI processes.

  The concatenation is done along the first dimension. The input tensors on
  the different processes must therefore have the same rank and the same
  shape, except for the first dimension, which may differ between processes.

  Two collective ops are issued. The first is a "sizing" allgather that
  gathers every process's first-dimension size. The second is the data
  allgather, which receives those sizes.

  Args:
    tensor: A tensor (or value convertible to one) of rank >= 1 to gather.
    name: Optional name for the data allgather op. Defaults to `'allgather'`;
      the sizing op is named `'<name>_sizing'`.

  Returns:
    A tensor of the same type as `tensor`, concatenated on dimension zero
    across all processes. The shape is identical to the input shape, except
    for the first dimension, which may be greater and is the sum of all
    first dimensions of the tensors in different MPI processes.

  Raises:
    ValueError: If `tensor` is statically known to be a scalar (rank 0),
      since a scalar has no first dimension to concatenate along.
  """
  # Convert once so the rank check, the shape op and the gather all see the
  # same tensor.
  tensor = tf.convert_to_tensor(tensor)
  # Without this check, a scalar input fails inside the slice below with an
  # opaque shape error rather than an explanation of the actual problem.
  if tensor.shape.ndims == 0:
    raise ValueError(
        'allgather requires a tensor of rank >= 1, got a scalar: {}'.format(
            tensor))
  # Specify that first allgather is to collect the tensor gather sizes,
  # indicated by passing in a scalar (0-D tensor) of value 0.
  sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const')
  # This process's first-dimension size, as a 1-element int64 vector.
  my_size = tf.slice(
      tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice')
  if name is None:
    name = 'allgather'
  sizing_name = '{}_sizing'.format(name)
  sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name)
  return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name)


ops.NotDifferentiable('MPIAllgather')
135