/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

class TrivialAllReduceTest : public HloTestBase {};

// Currently the CPU and GPU backends only support AllReduce with one
// replica, but we can at least check that the single-replica case works.

XLA_TEST_F(TrivialAllReduceTest, OneOperand) {
  const char* module_str = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p = f32[3] parameter(0)
    ROOT crs = f32[3] all-reduce(p), to_apply=add
  })";
  auto module =
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
          .ValueOrDie();
  auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
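  // With a single replica, the reduction sums exactly one contribution, so
  // the all-reduce must return its input unchanged.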
  EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
}

XLA_TEST_F(TrivialAllReduceTest, MultipleOperands) {
  const char* module_str = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p0 = f32[3] parameter(0)
    p1 = f32[2] parameter(1)
    ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
  })";
  auto module =
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
          .ValueOrDie();
  auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
  auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
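  // A multi-operand all-reduce yields a tuple; with one replica, each tuple
  // element equals the corresponding input operand.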
  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
            ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
}

// On the GPU backend, constants get special handling. Someone might pass a
// constant to an all-reduce (CRS), e.g. to count the number of replicas, so
// we need to make sure that case works.
XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
  const char* module_str = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p0 = f32[3] parameter(0)
    p1 = f32[2] constant({10, 20})
    ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
  })";
  auto module =
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
          .ValueOrDie();
  auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
  auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
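  // Only p0 is passed as an argument; p1 is a constant baked into the module,
  // yet it still appears as the second element of the result tuple.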
  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
            ExecuteAndTransfer(std::move(module), {&literal0}));
}

XLA_TEST_F(TrivialAllReduceTest, AllReduceU8) {
  const char* module_str = R"(
HloModule test

%AddComputation.15 {
  %x.16 = u8[] parameter(0)
  %y.17 = u8[] parameter(1)
  ROOT %add.18 = u8[] add(u8[] %x.16, u8[] %y.17)
}

ENTRY %test_computation {
  %constant.4 = u8[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.5 = u8[1]{0} reshape(u8[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.6 = u8[1]{0} broadcast(u8[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.7 = u8[] reshape(u8[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.8 = u8[8]{0} broadcast(u8[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %constant.2 = u8[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
  %reshape.3 = u8[1]{0} reshape(u8[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
  %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %dynamic-update-slice.10 = u8[8]{0} dynamic-update-slice(u8[8]{0} %broadcast.8, u8[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
  %convert.11 = u8[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %tuple.12 = (u8[8]{0}, u8[]) tuple(u8[8]{0} %dynamic-update-slice.10, u8[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.13 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.14 = u8[] get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %all-reduce.19 = (u8[8]{0}, u8[]) all-reduce(u8[8]{0} %get-tuple-element.13, u8[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.21 = u8[] get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %convert.22 = f32[] convert(u8[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.20 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  ROOT %tuple.23 = (u8[8]{0}) tuple(u8[8]{0} %get-tuple-element.20)
})";

  auto module =
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
          .ValueOrDie();
  auto literal_in = LiteralUtil::CreateR0<float>(0);
  auto literal0 = LiteralUtil::CreateR1<uint8_t>({1, 0, 0, 0, 0, 0, 0, 0});
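  // The entry computation broadcasts 0 into a u8[8] buffer and then writes 1
  // at index 0 via dynamic-update-slice; the single-replica all-reduce leaves
  // that buffer unchanged, so the result tuple holds {1, 0, 0, 0, 0, 0, 0, 0}.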
  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
            ExecuteAndTransfer(std::move(module), {&literal_in}));
}

XLA_TEST_F(TrivialAllReduceTest, AllReduceS32) {
  const char* module_str = R"(
HloModule test

%AddComputation.15 {
  %x.16 = s32[] parameter(0)
  %y.17 = s32[] parameter(1)
  ROOT %add.18 = s32[] add(s32[] %x.16, s32[] %y.17)
}

ENTRY %test_computation {
  %constant.4 = s32[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.5 = s32[1]{0} reshape(s32[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.6 = s32[1]{0} broadcast(s32[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.7 = s32[] reshape(s32[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.8 = s32[8]{0} broadcast(s32[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %constant.2 = s32[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
  %reshape.3 = s32[1]{0} reshape(s32[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
  %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %dynamic-update-slice.10 = s32[8]{0} dynamic-update-slice(s32[8]{0} %broadcast.8, s32[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
  %convert.11 = s32[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %tuple.12 = (s32[8]{0}, s32[]) tuple(s32[8]{0} %dynamic-update-slice.10, s32[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.13 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.14 = s32[] get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %all-reduce.19 = (s32[8]{0}, s32[]) all-reduce(s32[8]{0} %get-tuple-element.13, s32[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.21 = s32[] get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %convert.22 = f32[] convert(s32[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.20 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  ROOT %tuple.23 = (s32[8]{0}) tuple(s32[8]{0} %get-tuple-element.20)
})";

  auto module =
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
          .ValueOrDie();
  auto literal_in = LiteralUtil::CreateR0<float>(0);
  auto literal0 = LiteralUtil::CreateR1<int32>({1, 0, 0, 0, 0, 0, 0, 0});
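  // Same computation as AllReduceU8 but in s32: index 0 holds the 1 written
  // by the dynamic-update-slice, and the single-replica all-reduce is a no-op.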
  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
            ExecuteAndTransfer(std::move(module), {&literal_in}));
}

}  // namespace
}  // namespace xla