• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2005, 2006 Douglas Gregor <doug.gregor -at- gmail.com>.
2 
3 // Use, modification and distribution is subject to the Boost Software
4 // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
5 // http://www.boost.org/LICENSE_1_0.txt)
6 
7 // A test of the scan() collective.
8 #include <boost/mpi/collectives/scan.hpp>
9 #include <boost/mpi/communicator.hpp>
10 #include <boost/mpi/environment.hpp>
11 #include <algorithm>
12 #include <boost/serialization/string.hpp>
13 #include <boost/iterator/counting_iterator.hpp>
14 #include <boost/lexical_cast.hpp>
15 #include <numeric>
16 
17 #define BOOST_TEST_MODULE mpi_scan_test
18 #include <boost/test/included/unit_test.hpp>
19 
20 
21 using boost::mpi::communicator;
22 
23 // A simple point class that we can build, add, compare, and
24 // serialize.
25 struct point
26 {
pointpoint27   point() : x(0), y(0), z(0) { }
pointpoint28   point(int x, int y, int z) : x(x), y(y), z(z) { }
29 
30   int x;
31   int y;
32   int z;
33 
34  private:
35   template<typename Archiver>
serializepoint36   void serialize(Archiver& ar, unsigned int /*version*/)
37   {
38     ar & x & y & z;
39   }
40 
41   friend class boost::serialization::access;
42 };
43 
operator <<(std::ostream & out,const point & p)44 std::ostream& operator<<(std::ostream& out, const point& p)
45 {
46   return out << p.x << ' ' << p.y << ' ' << p.z;
47 }
48 
operator ==(const point & p1,const point & p2)49 bool operator==(const point& p1, const point& p2)
50 {
51   return p1.x == p2.x && p1.y == p2.y && p1.z == p2.z;
52 }
53 
operator !=(const point & p1,const point & p2)54 bool operator!=(const point& p1, const point& p2)
55 {
56   return !(p1 == p2);
57 }
58 
// Component-wise sum of two points; this is the operation std::plus<point>
// applies in the "points" scan test.
point operator+(const point& p1, const point& p2)
{
  const int sum_x = p1.x + p2.x;
  const int sum_y = p1.y + p2.y;
  const int sum_z = p1.z + p2.z;
  return point(sum_x, sum_y, sum_z);
}
63 
64 namespace boost { namespace mpi {
65 
66   template <>
67   struct is_mpi_datatype<point> : public mpl::true_ { };
68 
69 } } // end namespace boost::mpi
70 
71 template<typename Generator, typename Op>
72 void
scan_test(const communicator & comm,Generator generator,const char * type_kind,Op op,const char * op_kind)73 scan_test(const communicator& comm, Generator generator,
74           const char* type_kind, Op op, const char* op_kind)
75 {
76   typedef typename Generator::result_type value_type;
77   value_type value = generator(comm.rank());
78   using boost::mpi::scan;
79 
80   if (comm.rank() == 0) {
81     std::cout << "Prefix reducing to " << op_kind << " of " << type_kind
82               << "...";
83     std::cout.flush();
84   }
85 
86   value_type result_value;
87   scan(comm, value, result_value, op);
88   value_type scan_result = scan(comm, value, op);
89   BOOST_CHECK(scan_result == result_value);
90 
91   // Compute expected result
92   std::vector<value_type> generated_values;
93   for (int p = 0; p < comm.size(); ++p)
94     generated_values.push_back(generator(p));
95   std::vector<value_type> expected_results(comm.size());
96   std::partial_sum(generated_values.begin(), generated_values.end(),
97                    expected_results.begin(), op);
98   BOOST_CHECK(result_value == expected_results[comm.rank()]);
99   if (comm.rank() == 0) std::cout << "Done." << std::endl;
100 
101   (comm.barrier)();
102 }
103 
// Produces the integer inputs for scan(): rank p contributes base + p
// (base defaults to 1, so rank 0 contributes 1).
struct int_generator
{
  typedef int result_type;  // value type consumed by scan_test

  int_generator(int start = 1) : base(start) { }

  int operator()(int p) const
  {
    return p + base;
  }

 private:
  int base;
};
116 
117 // Generate points to test with scan()
118 struct point_generator
119 {
120   typedef point result_type;
121 
point_generatorpoint_generator122   point_generator(point origin) : origin(origin) { }
123 
operator ()point_generator124   point operator()(int p) const
125   {
126     return point(origin.x + 1, origin.y + 1, origin.z + 1);
127   }
128 
129  private:
130   point origin;
131 };
132 
133 struct string_generator
134 {
135   typedef std::string result_type;
136 
operator ()string_generator137   std::string operator()(int p) const
138   {
139     std::string result = boost::lexical_cast<std::string>(p);
140     result += " rosebud";
141     if (p != 1) result += 's';
142     return result;
143   }
144 };
145 
// Bitwise AND on ints, used by scan_check as a user-defined operation on a
// built-in MPI datatype.
struct secret_int_bit_and
{
  int operator()(int lhs, int rhs) const
  {
    const int masked = lhs & rhs;
    return masked;
  }
};
150 
// An int wrapped in a struct, used by scan_check as an "arbitrary"
// user-defined type. Unlike point, no is_mpi_datatype specialization is
// declared for it here, so (presumably) Boost.MPI moves it via the
// serialize() hook below — confirm against the Boost.MPI docs.
struct wrapped_int
{
  wrapped_int() : value(0) { }
  explicit wrapped_int(int v) : value(v) { }

  // Boost.Serialization hook: the archive holds just the wrapped value.
  template<typename Archive>
  void serialize(Archive& archive, unsigned int /* version */)
  {
    archive & value;
  }

  int value;  // the wrapped integer
};
164 
// Adds two wrapped ints by summing their underlying values.
wrapped_int operator+(const wrapped_int& x, const wrapped_int& y)
{
  const int total = x.value + y.value;
  return wrapped_int(total);
}
169 
operator ==(const wrapped_int & x,const wrapped_int & y)170 bool operator==(const wrapped_int& x, const wrapped_int& y)
171 {
172   return x.value == y.value;
173 }
174 
175 // Generates wrapped_its to test with scan()
176 struct wrapped_int_generator
177 {
178   typedef wrapped_int result_type;
179 
wrapped_int_generatorwrapped_int_generator180   wrapped_int_generator(int base = 1) : base(base) { }
181 
operator ()wrapped_int_generator182   wrapped_int operator()(int p) const { return wrapped_int(base + p); }
183 
184  private:
185   int base;
186 };
187 
188 namespace boost { namespace mpi {
189 
190 // Make std::plus<wrapped_int> commutative.
191 template<>
192 struct is_commutative<std::plus<wrapped_int>, wrapped_int>
193   : mpl::true_ { };
194 
195 } } // end namespace boost::mpi
196 
// Drives scan_test over a matrix of value types and operations. Every rank
// executes the same scan_test calls in the same order — required because
// each call performs MPI collectives — so keep this sequence identical on
// all ranks when editing.
BOOST_AUTO_TEST_CASE(scan_check)
{
  using namespace boost::mpi;
  environment env;
  communicator comm;

  // Built-in MPI datatypes with built-in MPI operations.
  // NOTE(review): rank p contributes p + 1, so the product scan computes N!
  // on the last rank; that overflows int once the job has more than 12 ranks.
  const int_generator ints;
  scan_test(comm, ints, "integers", std::plus<int>(), "sum");
  scan_test(comm, ints, "integers", std::multiplies<int>(), "product");
  scan_test(comm, ints, "integers", maximum<int>(), "maximum");
  scan_test(comm, ints, "integers", minimum<int>(), "minimum");

  // A user-defined MPI datatype with an operation that has the same name
  // as a built-in operation.
  const point_generator points(point(0, 0, 0));
  scan_test(comm, points, "points", std::plus<point>(), "sum");

  // A built-in MPI datatype with a user-defined operation.
  const int_generator ints_from_17(17);
  scan_test(comm, ints_from_17, "integers", secret_int_bit_and(),
            "bitwise and");

  // An arbitrary type with a user-defined, commutative operation.
  const wrapped_int_generator wrapped_from_17(17);
  scan_test(comm, wrapped_from_17, "wrapped integers",
            std::plus<wrapped_int>(), "sum");

  // An arbitrary type with a (non-commutative) user-defined operation.
  const string_generator strings;
  scan_test(comm, strings, "strings",
            std::plus<std::string>(), "concatenation");
}
229