# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

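"""End-to-end test for tensorflow.contrib.mpi_collectives allreduce."""
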
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf
import tensorflow.contrib.mpi_collectives as mpi
from tensorflow.python.platform import test


# If True, each allreduce averages (divides the sum by the number of ranks)
# instead of just summing.
average_allreduce = False
# Stop dumpFailure after this many mismatches; values <= 0 dump everything.
max_wrong_count = -1


class AllreduceTest(test.TestCase):
  def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red,
                  our_correct):
    # Print every (row, col) entry of the local-reduced and allreduced
    # outputs next to the expected values, flagging mismatches, so we can
    # see exactly which entries are incorrect.
    wrong_count = 0
    red_dims = out_loc_red.shape
    assert len(red_dims) == 2
    for i in range(red_dims[0]):
      for j in range(red_dims[1]):
        suffix = ""
        if out_loc_red[i][j] != my_correct[i][j] or \
           out_all_red[i][j] != our_correct[i][j]:
          suffix = "WRONG"
          wrong_count += 1
        print("{}\t{}\t{}\t{}\t{}\t{}"
              .format(my_rank, i, j, out_loc_red[i][j],
                      out_all_red[i][j], suffix), flush=True)
        if max_wrong_count > 0 and wrong_count >= max_wrong_count:
          return

  def test_mpi_allreduce(self):
    # Get the MPI rank and world size from the environment.
    my_rank = int(os.environ['PMI_RANK'])
    num_ranks = int(os.environ['PMI_SIZE'])
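    # NOTE: PMI_RANK/PMI_SIZE are exported by PMI-based launchers such as
    # MPICH's Hydra (mpiexec). Other launchers use different variables; for
    # example, Open MPI's mpirun sets OMPI_COMM_WORLD_RANK and
    # OMPI_COMM_WORLD_SIZE, so a more portable lookup could be:
    #   my_rank = int(os.environ.get('PMI_RANK',
    #                                os.environ.get('OMPI_COMM_WORLD_RANK', 0)))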

    stages = 13
    batch_size = 1331
    hidden_size = batch_size
    out_size = batch_size

    # Input placeholder (batch_size x hidden_size); the test feeds it with
    # ones plus a rank-dependent offset.
    inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size),
                            name="Input")

    # Weight matrices (hidden_size x out_size); weights_i is initialized to
    # the constant 2^(i+1).
    weights = []
    for i in range(stages):
      initer = tf.constant_initializer(pow(2.0, i + 1.0))
      weights.append(tf.get_variable("weights_{}".format(i),
                                     shape=(hidden_size, out_size),
                                     dtype=tf.float32,
                                     initializer=initer))

    # Chain the stages together: each stage adds its weight matrix and then
    # allreduces the result across all ranks, so stage i depends on the
    # allreduce output of stage i - 1.
    stage_input = inputs
    for i in range(stages):
      inter_output = tf.add(stage_input, weights[i],
                            name="add_red_{}".format(i))
      stage_input = mpi.allreduce(inter_output,
                                  average=average_allreduce)

    all_reduced = stage_input
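    # Note: with a summing allreduce, each stage multiplies its input by
    # num_ranks after adding 2^(i+1); the rank-dependent spread in the feeds
    # is mixed across ranks exactly once, by the stage-0 allreduce, since all
    # ranks hold identical values from then on. The expected values computed
    # in the session below follow this recurrence.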

    # Local reduced output for verification
    local_input = inputs
    for i in range(stages):
      inter_output = tf.add(local_input, weights[i],
                            name="addin_loc_{}".format(i))
      my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)),
                               dtype=tf.float32, name="loc_redr_{}".format(i))
      for r in range(num_ranks):
        my_reducer = tf.add(my_reducer, inter_output,
                            name="add_loc_{}_{}".format(i, r))
      if average_allreduce:
        local_input = tf.div(my_reducer, num_ranks,
                             name="div_loc_{}".format(i))
      else:
        local_input = my_reducer

    local_reduced = local_input
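    # Starting from the Variable's initial ones and adding this rank's stage
    # output num_ranks times means each stage yields num_ranks * x + 1; the
    # trailing "+ 1" is why the expected per-rank value computed below is
    # curr_feed * num_ranks + 1.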

    # NOTE: This assumes that GPU device IDs on the host are numbered the
    # same as MPI ranks, i.e. rank i is pinned to GPU i.
    gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
    config = tf.ConfigProto(gpu_options=gpu_options)

    # MPI session to test allreduce; every rank must run the initializer
    # before the collective ops execute.
    with mpi.Session(config=config) as sess:
      sess.run(tf.global_variables_initializer())

      input_feed = np.ones((batch_size, hidden_size), dtype=np.float32)
      our_output = input_feed[0][0]
      spread_var = 100
      input_feed = input_feed + my_rank * spread_var
      my_output = input_feed[0][0]
      # Compute the expected scalar value of every output entry: my_output
      # tracks the local-reduction path and our_output the allreduce path.
      for i in range(stages):
        curr_feed = my_output + pow(2.0, i + 1.0)
        my_output = curr_feed * num_ranks + 1
        curr_our_feed = our_output + pow(2.0, i + 1.0)
        if i == 0:
          # The rank-dependent offsets sum to
          # spread_var * (0 + 1 + ... + (num_ranks - 1)) and only enter at
          # the first allreduce.
          sum_ranks = num_ranks * (num_ranks - 1) / 2
          our_output = curr_our_feed * num_ranks + \
            spread_var * sum_ranks
        else:
          our_output = curr_our_feed * num_ranks

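      # Worked example for num_ranks = 2: every rank feeds 1 + 100 * rank, so
      # at stage 0 the allreduce path sums (1 + 2) + (101 + 2) = 106, which
      # matches curr_our_feed * num_ranks + spread_var * sum_ranks
      # = 3 * 2 + 100 * 1.
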
      print("rank {}: My output is {}".format(my_rank, my_output))
      my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
      my_correct = my_correct + my_output
      print("rank {}: Our output is {}".format(my_rank, our_output))
      our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
      our_correct = our_correct + our_output

      # Run the graph repeatedly, checking both outputs against the expected
      # values on every iteration.
      for i in range(1000):
        if i % 100 == 0:
          print("{}: iter {}".format(my_rank, i), flush=True)
        feed_dict = {inputs: input_feed}
        out_all_red, out_loc_red \
          = sess.run([all_reduced, local_reduced],
                     feed_dict=feed_dict)

        if not np.allclose(out_loc_red, my_correct) or \
           not np.allclose(out_all_red, our_correct):
          print("Test incorrect on iter {}".format(i), flush=True)
          self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red,
                           our_correct)
          assert(np.allclose(out_loc_red, my_correct) and
                 np.allclose(out_all_red, our_correct))


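# To exercise the MPI code paths, launch one process per rank through your
# MPI runtime, e.g. (the exact command depends on the installation):
#   mpiexec -n 2 python mpi_allreduce_test.py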
if __name__ == '__main__':
  test.main()