# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tensorflow.contrib.mpi_collectives.mpi_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import itertools

import tensorflow as tf

import tensorflow.contrib.mpi_collectives as mpi


def mpi_env_rank_and_size():
  """Get MPI rank and size from environment variables and return them as a
  tuple of integers.

  Most MPI implementations have an `mpirun` or `mpiexec` command that will
  run an MPI executable and set up all communication necessary between the
  different processes. As part of that setup, they will set environment
  variables that contain the rank and size of the MPI_COMM_WORLD
  communicator. We can read those environment variables from Python in order
  to ensure that `mpi.rank()` and `mpi.size()` return the expected values.

  Since MPI is just a standard, not an implementation, implementations
  typically choose their own environment variable names. This function tries
  to support several different implementations, but really it only needs to
  support whatever implementation we want to use for the TensorFlow test
  suite.

  If this is not running under MPI, then defaults of rank zero and size one
  are returned. (This is appropriate because when you call MPI_Init in an
  application not started with mpirun, it will create a new independent
  communicator with only one process in it.)
  """
  rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split()
  size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split()

  for rank_var, size_var in zip(rank_env, size_env):
    rank = os.environ.get(rank_var)
    size = os.environ.get(size_var)
    if rank is not None and size is not None:
      return int(rank), int(size)

  # Default to rank zero and size one if there are no environment variables
  return 0, 1


class MPITests(tf.test.TestCase):
  """
  Tests for MPI ops in tensorflow.contrib.mpi_collectives.
  """

  def test_mpi_rank(self):
    """Test that the rank returned by mpi.rank() is correct."""
    true_rank, _ = mpi_env_rank_and_size()
    with self.test_session() as session:
      rank = session.run(mpi.rank())
      self.assertEqual(true_rank, rank)

  def test_mpi_size(self):
    """Test that the size returned by mpi.size() is correct."""
    _, true_size = mpi_env_rank_and_size()
    with self.test_session() as session:
      size = session.run(mpi.size())
      self.assertEqual(true_size, size)

  def test_mpi_allreduce_cpu(self):
    """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors."""
    with self.test_session() as session:
      size = session.run(mpi.size())

      dtypes = [tf.int32, tf.float32]
      dims = [1, 2, 3]
      for dtype, dim in itertools.product(dtypes, dims):
        tf.set_random_seed(1234)
        tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype)
        summed = mpi.allreduce(tensor, average=False)
        multiplied = tensor * size
        max_difference = tf.reduce_max(tf.abs(summed - multiplied))

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3:
          threshold = 0
        elif size < 10:
          threshold = 1e-4
        elif size < 15:
          threshold = 5e-4
        else:
          break

        diff = session.run(max_difference)
        self.assertTrue(diff <= threshold,
                        "mpi.allreduce produces incorrect results")

  def test_mpi_allreduce_gpu(self):
    """Test that the allreduce works on GPUs.

    This test will crash badly if used with an MPI implementation that does
    not support GPU memory transfers directly, as it will call MPI_Send on
    a GPU data pointer."""
    # Only do this test if there are GPUs available.
    if not tf.test.is_gpu_available(cuda_only=True):
      return

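    # Find this process's local rank using a CPU-only session, then restrict
    # the GPU session below to the single GPU with that index, so that each
    # MPI process drives a distinct device.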
    no_gpus = tf.GPUOptions(visible_device_list="")
    cpu_config = tf.ConfigProto(gpu_options=no_gpus)
    with self.test_session(config=cpu_config) as session:
      local_rank = session.run(mpi.local_rank())

    one_gpu = tf.GPUOptions(visible_device_list=str(local_rank))
    gpu_config = tf.ConfigProto(gpu_options=one_gpu)
    with self.test_session(config=gpu_config) as session:
      size = session.run(mpi.size())

      dtype = tf.float32
      dim = 3
      with tf.device("/gpu:0"):
        tf.set_random_seed(1234)
        tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype)
        summed = mpi.allreduce(tensor, average=False)
        multiplied = tensor * size
        max_difference = tf.reduce_max(tf.abs(summed - multiplied))

      # Threshold for floating point equality depends on number of
      # ranks, since we're comparing against precise multiplication.
      if size <= 3:
        threshold = 0
      elif size < 10:
        threshold = 1e-4
      elif size < 15:
        threshold = 5e-4
      else:
        return

      diff = session.run(max_difference)
      self.assertTrue(diff <= threshold,
                      "mpi.allreduce on GPU produces incorrect results")

  def test_mpi_allreduce_error(self):
    """Test that the allreduce raises an error if different ranks try to
    send tensors of different rank or dimension."""
    with self.test_session() as session:
      rank = session.run(mpi.rank())
      size = session.run(mpi.size())

      # This test does not apply if there is only one worker.
      if size == 1:
        return

      # Same rank, different dimension
      tf.set_random_seed(1234)
      dims = [17 + rank] * 3
      tensor = tf.random_uniform(dims, -1.0, 1.0)
      with self.assertRaises(tf.errors.FailedPreconditionError):
        session.run(mpi.allreduce(tensor))

      # Same number of elements, different rank
      tf.set_random_seed(1234)
      if rank == 0:
        dims = [17, 23 * 57]
      else:
        dims = [17, 23, 57]
      tensor = tf.random_uniform(dims, -1.0, 1.0)
      with self.assertRaises(tf.errors.FailedPreconditionError):
        session.run(mpi.allreduce(tensor))

  def test_mpi_allreduce_type_error(self):
    """Test that the allreduce raises an error if different ranks try to
    send tensors of different type."""
    with self.test_session() as session:
      rank = session.run(mpi.rank())
      size = session.run(mpi.size())

      # This test does not apply if there is only one worker.
      if size == 1:
        return

      # Same shape, different type
      dims = [17] * 3
      tensor = tf.ones(dims, dtype=tf.int32 if rank % 2 == 0 else tf.float32)
      with self.assertRaises(tf.errors.FailedPreconditionError):
        session.run(mpi.allreduce(tensor))

  def test_mpi_allgather(self):
    """Test that the allgather correctly gathers 1D, 2D, 3D tensors."""
    with self.test_session() as session:
      size = session.run(mpi.size())
      rank = session.run(mpi.rank())

      dtypes = tf.int32, tf.float32
      dims = 1, 2, 3
      for dtype, dim in itertools.product(dtypes, dims):
        tensor = tf.ones([17] * dim, dtype=dtype) * rank
        gathered = mpi.allgather(tensor)

        gathered_tensor = session.run(gathered)
        self.assertEqual(list(gathered_tensor.shape),
                         [17 * size] + [17] * (dim - 1))

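        # Check that the block gathered from each rank contains only that
        # rank's value.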
        for i in range(size):
          rank_tensor = tf.slice(gathered_tensor, [i * 17] + [0] * (dim - 1),
                                 [17] + [-1] * (dim - 1))
          self.assertEqual(list(rank_tensor.shape), [17] * dim)
          self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
                          "mpi.allgather produces incorrect gathered tensor")

  def test_mpi_allgather_variable_size(self):
    """Test that the allgather correctly gathers 1D, 2D, 3D tensors,
    even if those tensors have different sizes along the first dim."""
    with self.test_session() as session:
      size = session.run(mpi.size())
      rank = session.run(mpi.rank())

      dtypes = tf.int32, tf.float32
      dims = 1, 2, 3
      for dtype, dim in itertools.product(dtypes, dims):
        # Support tests up to MPI Size of 35
        if size > 35:
          break

        tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5
        tensor_sizes = tensor_sizes[:size]

        tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1),
                         dtype=dtype) * rank
        gathered = mpi.allgather(tensor)

        gathered_tensor = session.run(gathered)
        expected_size = sum(tensor_sizes)
        self.assertEqual(list(gathered_tensor.shape),
                         [expected_size] + [17] * (dim - 1))

        for i in range(size):
          rank_size = [tensor_sizes[i]] + [17] * (dim - 1)
          rank_tensor = tf.slice(gathered,
                                 [sum(tensor_sizes[:i])] + [0] * (dim - 1),
                                 rank_size)
          self.assertEqual(list(rank_tensor.shape), rank_size)
          self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
                          "mpi.allgather produces incorrect gathered tensor")

  def test_mpi_allgather_error(self):
    """Test that the allgather returns an error if any dimension besides
    the first is different among the tensors being gathered."""
    with self.test_session() as session:
      rank = session.run(mpi.rank())
      size = session.run(mpi.size())

      # This test does not apply if there is only one worker.
      if size == 1:
        return

      tensor_size = [17] * 3
      tensor_size[1] = 10 * (rank + 1)
      tensor = tf.ones(tensor_size, dtype=tf.float32) * rank
      with self.assertRaises(tf.errors.FailedPreconditionError):
        session.run(mpi.allgather(tensor))

  def test_mpi_allgather_type_error(self):
    """Test that the allgather returns an error if the types being gathered
    differ among the processes."""
    with self.test_session() as session:
      rank = session.run(mpi.rank())
      size = session.run(mpi.size())

      # This test does not apply if there is only one worker.
      if size == 1:
        return

      tensor_size = [17] * 3
      dtype = tf.int32 if rank % 2 == 0 else tf.float32
      tensor = tf.ones(tensor_size, dtype=dtype) * rank
      with self.assertRaises(tf.errors.FailedPreconditionError):
        session.run(mpi.allgather(tensor))


if __name__ == '__main__':
  tf.test.main()