# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.mpi_collectives as mpi
from tensorflow.python.platform import test


average_allreduce = False
max_wrong_count = -1


class AllreduceTest(test.TestCase):
  def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red,
                  our_correct):
    # Find reduced/allreduced indices that are wrong and print all the
    # values from output, slices, reduced, allreduced, so we can debug
    # which is incorrect:
    wrong_count = 0
    red_dims = out_loc_red.shape
    assert(len(red_dims) == 2)
    for i in range(red_dims[0]):
      for j in range(red_dims[1]):
        suffix = ""
        if out_loc_red[i][j] != my_correct[i][j] or \
           out_all_red[i][j] != our_correct[i][j]:
          suffix = "WRONG"
          wrong_count += 1
        print("{}\t{}\t{}\t{}\t{}\t{}"
              .format(my_rank, i, j, out_loc_red[i][j],
                      out_all_red[i][j], suffix), flush=True)
        if max_wrong_count > 0 and wrong_count >= max_wrong_count:
          return

  def test_mpi_allreduce(self):
    # Get MPI rank
    my_rank = int(os.environ['PMI_RANK'])
    num_ranks = int(os.environ['PMI_SIZE'])

    stages = 13
    batch_size = 1331
    hidden_size = batch_size
    out_size = batch_size

    # Input placeholder (batch_size x hidden) - init to 1s
    inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size),
                            name="Input")

    # Large matrices (hidden x out_dim) - init random
    weights = []
    for i in range(stages):
      initer = tf.constant_initializer(pow(2.0, i + 1.0))
      weights.append(tf.get_variable("weights_{}".format(i),
                                     shape=(hidden_size, out_size),
                                     dtype=tf.float32,
                                     initializer=initer))

    # Calculate output through dependent allreduces
    stage_input = inputs
    for i in range(stages):
      inter_output = tf.add(stage_input, weights[i],
                            name="add_red_{}".format(i))
      stage_input = mpi.allreduce(inter_output,
                                  average=average_allreduce)

    all_reduced = stage_input

    # Local reduced output for verification
    local_input = inputs
    for i in range(stages):
      inter_output = tf.add(local_input, weights[i],
                            name="addin_loc_{}".format(i))
      my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)),
                               dtype=tf.float32, name="loc_redr_{}".format(i))
      for r in range(num_ranks):
        my_reducer = tf.add(my_reducer, inter_output,
                            name="add_loc_{}_{}".format(i, r))
      if average_allreduce:
        local_input = tf.div(my_reducer, num_ranks,
                             name="div_loc_{}".format(i))
      else:
        local_input = my_reducer

    local_reduced = local_input

    # NOTE: This assumes that device IDs are numbered the same as ranks
    gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
    config = tf.ConfigProto(gpu_options=gpu_options)

    # MPI Session to test allreduce
    with mpi.Session(config=config) as sess:
      sess.run(tf.global_variables_initializer())

      input_feed = np.ones((batch_size, hidden_size), dtype=np.float32)
      our_output = input_feed[0][0]
      spread_var = 100
      input_feed = input_feed + my_rank * spread_var
      my_output = input_feed[0][0]
      for i in range(stages):
        curr_feed = my_output + pow(2.0, i + 1.0)
        my_output = curr_feed * num_ranks + 1
        curr_our_feed = our_output + pow(2.0, i + 1.0)
        if i == 0:
          sum_ranks = num_ranks * (num_ranks - 1) / 2
          our_output = curr_our_feed * num_ranks + \
                       spread_var * sum_ranks
        else:
          our_output = curr_our_feed * num_ranks

      print("rank {}: My output is {}".format(my_rank, my_output))
      my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
      my_correct = my_correct + my_output
      print("rank {}: Our output is {}".format(my_rank, our_output))
      our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
      our_correct = our_correct + our_output

      for i in range(1000):
        if i % 100 == 0:
          print("{}: iter {}".format(my_rank, i), flush=True)
        feed_dict = {inputs: input_feed}
        out_all_red, out_loc_red \
            = sess.run([all_reduced, local_reduced],
                       feed_dict=feed_dict)

        if not np.allclose(out_loc_red, my_correct) or \
           not np.allclose(out_all_red, our_correct):
          print("Test incorrect on iter {}".format(i), flush=True)
          self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red,
                           our_correct)
          assert(np.allclose(out_loc_red, my_correct) and
                 np.allclose(out_all_red, our_correct))


if __name__ == '__main__':
  test.main()