• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal.h"
17 #include "tensorflow/compiler/xla/shape_util.h"
18 #include "tensorflow/compiler/xla/test.h"
19 #include "tensorflow/compiler/xla/test_helpers.h"
20 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
21 #include "tensorflow/compiler/xla/tests/test_macros.h"
22 
23 namespace xla {
24 namespace {
25 
26 class TrivialAllReduceTest : public HloTestBase {};
27 
28 // Currently the CPU and GPU backends only support AllReduce with one
29 // replica.  But we can at least check this.
30 
// Feeds one f32[3] parameter through an all-reduce whose reduction is scalar
// addition and expects the output to equal the input.
XLA_TEST_F(TrivialAllReduceTest, OneOperand) {
  const char* const hlo_text = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p = f32[3] parameter(0)
    ROOT crs = f32[3] all-reduce(p), to_apply=add
  })";

  // Parse and verify the HLO text; ValueOrDie() dies if that did not succeed.
  auto parsed =
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
  auto hlo_module = std::move(parsed).ValueOrDie();

  // The same literal serves as both the argument and the expected result.
  auto input = LiteralUtil::CreateR1<float>({1, 2, 3});
  auto output = ExecuteAndTransfer(std::move(hlo_module), {&input});
  EXPECT_EQ(input, output);
}
51 
// Passes two parameters of different sizes (f32[3] and f32[2]) to a single
// variadic all-reduce and expects a tuple holding both inputs unchanged.
XLA_TEST_F(TrivialAllReduceTest, MultipleOperands) {
  const char* const hlo_text = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p0 = f32[3] parameter(0)
    p1 = f32[2] parameter(1)
    ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
  })";

  // Parse and verify the HLO text; ValueOrDie() dies if that did not succeed.
  auto parsed =
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
  auto hlo_module = std::move(parsed).ValueOrDie();

  auto first_arg = LiteralUtil::CreateR1<float>({1, 2, 3});
  auto second_arg = LiteralUtil::CreateR1<float>({10, 20});

  // The expected result is simply the two arguments packed into a tuple.
  auto expected = LiteralUtil::MakeTuple({&first_arg, &second_arg});
  auto actual =
      ExecuteAndTransfer(std::move(hlo_module), {&first_arg, &second_arg});
  EXPECT_EQ(expected, actual);
}
75 
76 // On the GPU backend, constants get special handling.  Someone might pass a
77 // constant to CRS to e.g. count the number of replicas -- we need to make sure
78 // it works.
XLA_TEST_F(TrivialAllReduceTest,ConstantOperand)79 XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
80   const char* module_str = R"(
81   HloModule test
82 
83   add {
84     x = f32[] parameter(0)
85     y = f32[] parameter(1)
86     add = f32[] add(x, y)
87   }
88 
89   ENTRY test_computation {
90     p0 = f32[3] parameter(0)
91     p1 = f32[2] constant({10, 20})
92     ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
93   })";
94   auto module =
95       ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
96           .ValueOrDie();
97   auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
98   auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
99   EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
100             ExecuteAndTransfer(std::move(module), {&literal0}));
101 }
102 
// All-reduce over u8 operands.  The HLO builds a u8[8] of zeros, writes a 1 at
// index 0 via dynamic-update-slice, and all-reduces it together with a u8
// scalar converted from the f32 parameter.  Only the u8[8] element of the
// all-reduce result is returned, wrapped in a one-element tuple.
XLA_TEST_F(TrivialAllReduceTest, AllReduceU8) {
  const char* const hlo_text = R"(
HloModule test

%AddComputation.15 {
  %x.16 = u8[] parameter(0)
  %y.17 = u8[] parameter(1)
  ROOT %add.18 = u8[] add(u8[] %x.16, u8[] %y.17)
}

ENTRY %test_computation {
  %constant.4 = u8[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.5 = u8[1]{0} reshape(u8[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.6 = u8[1]{0} broadcast(u8[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.7 = u8[] reshape(u8[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.8 = u8[8]{0} broadcast(u8[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %constant.2 = u8[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
  %reshape.3 = u8[1]{0} reshape(u8[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
  %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %dynamic-update-slice.10 = u8[8]{0} dynamic-update-slice(u8[8]{0} %broadcast.8, u8[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
  %convert.11 = u8[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %tuple.12 = (u8[8]{0}, u8[]) tuple(u8[8]{0} %dynamic-update-slice.10, u8[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.13 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.14 = u8[] get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %all-reduce.19 = (u8[8]{0}, u8[]) all-reduce(u8[8]{0} %get-tuple-element.13, u8[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.21 = u8[] get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %convert.22 = f32[] convert(u8[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.20 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  ROOT %tuple.23 = (u8[8]{0}) tuple(u8[8]{0} %get-tuple-element.20)
})";

  // Parse and verify the HLO text; ValueOrDie() dies if that did not succeed.
  auto parsed =
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
  auto hlo_module = std::move(parsed).ValueOrDie();

  // Scalar f32 argument for parameter(0); the returned tuple does not include
  // the value derived from it.
  auto scalar_arg = LiteralUtil::CreateR0<float>(0);
  // Eight elements, with the 1 at index 0 that the dynamic-update-slice wrote.
  auto expected_vector =
      LiteralUtil::CreateR1<uint8_t>({1, 0, 0, 0, 0, 0, 0, 0});

  auto expected = LiteralUtil::MakeTuple({&expected_vector});
  auto actual = ExecuteAndTransfer(std::move(hlo_module), {&scalar_arg});
  EXPECT_EQ(expected, actual);
}
143 
// All-reduce over s32 operands.  The HLO builds an s32[8] of zeros, writes a 1
// at index 0 via dynamic-update-slice, and all-reduces it together with an s32
// scalar converted from the f32 parameter.  Only the s32[8] element of the
// all-reduce result is returned, wrapped in a one-element tuple.
XLA_TEST_F(TrivialAllReduceTest, AllReduceS32) {
  const char* const hlo_text = R"(

HloModule test

%AddComputation.15 {
  %x.16 = s32[] parameter(0)
  %y.17 = s32[] parameter(1)
  ROOT %add.18 = s32[] add(s32[] %x.16, s32[] %y.17)
}

ENTRY %test_computation {
  %constant.4 = s32[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.5 = s32[1]{0} reshape(s32[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.6 = s32[1]{0} broadcast(s32[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %reshape.7 = s32[] reshape(s32[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %broadcast.8 = s32[8]{0} broadcast(s32[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
  %constant.2 = s32[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
  %reshape.3 = s32[1]{0} reshape(s32[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
  %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %dynamic-update-slice.10 = s32[8]{0} dynamic-update-slice(s32[8]{0} %broadcast.8, s32[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
  %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
  %convert.11 = s32[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %tuple.12 = (s32[8]{0}, s32[]) tuple(s32[8]{0} %dynamic-update-slice.10, s32[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.13 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.14 = s32[] get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %all-reduce.19 = (s32[8]{0}, s32[]) all-reduce(s32[8]{0} %get-tuple-element.13, s32[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.21 = s32[] get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %convert.22 = f32[] convert(s32[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  %get-tuple-element.20 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
  ROOT %tuple.23 = (s32[8]{0}) tuple(s32[8]{0} %get-tuple-element.20)
})";

  // Parse and verify the HLO text; ValueOrDie() dies if that did not succeed.
  auto parsed =
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
  auto hlo_module = std::move(parsed).ValueOrDie();

  // Scalar f32 argument for parameter(0); the returned tuple does not include
  // the value derived from it.
  auto scalar_arg = LiteralUtil::CreateR0<float>(0);
  // Eight elements, with the 1 at index 0 that the dynamic-update-slice wrote.
  auto expected_vector = LiteralUtil::CreateR1<int32>({1, 0, 0, 0, 0, 0, 0, 0});

  auto expected = LiteralUtil::MakeTuple({&expected_vector});
  auto actual = ExecuteAndTransfer(std::move(hlo_module), {&scalar_arg});
  EXPECT_EQ(expected, actual);
}
185 
186 }  // namespace
187 }  // namespace xla
188