/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "tensorflow/compiler/xla/client/computation_builder.h"
20 #include "tensorflow/compiler/xla/client/global_data.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/types.h"
28
29 namespace xla {
30 namespace {
31
32 class SelectTest : public ClientLibraryTestBase {
33 public:
34 ErrorSpec error_spec_{0.0001};
35 };
36
// Scalar F32 select with a constant `true` predicate: the on-true operand
// (123.0f) is the expected result.
TEST_F(SelectTest, SelectScalarF32True) {
  ComputationBuilder builder(client_, TestName());
  auto pred = builder.ConstantR0<bool>(true);
  auto on_true = builder.ConstantR0<float>(123.0f);
  auto on_false = builder.ConstantR0<float>(42.0f);
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(pred, on_true, on_false);

  ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
}
46
// Scalar S32 select with a constant `true` predicate: the on-true operand
// (-42) is the expected result. Integer comparison, so no error bound is
// passed.
TEST_F(SelectTest, SelectScalarS32True) {
  ComputationBuilder builder(client_, TestName());
  auto pred = builder.ConstantR0<bool>(true);
  auto on_true = builder.ConstantR0<int32>(-42);
  auto on_false = builder.ConstantR0<int32>(42);
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(pred, on_true, on_false);

  ComputeAndCompareR0<int32>(&builder, -42, {});
}
56
// Scalar F32 select with a constant `false` predicate: the on-false operand
// (42.0f) is the expected result.
TEST_F(SelectTest, SelectScalarF32False) {
  ComputationBuilder builder(client_, TestName());
  auto pred = builder.ConstantR0<bool>(false);
  auto on_true = builder.ConstantR0<float>(123.0f);
  auto on_false = builder.ConstantR0<float>(42.0f);
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(pred, on_true, on_false);

  ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
}
66
// Select where the predicate and both operands are zero-element R1 constants;
// the expected result is an empty R1 as well.
XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
  ComputationBuilder builder(client_, TestName());
  auto pred = builder.ConstantR1<bool>({});
  auto on_true = builder.ConstantR1<float>({});
  auto on_false = builder.ConstantR1<float>({});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
76
// Element-wise select between two five-element R1F32 constants driven by a
// constant R1 predicate. Each expected element comes from on_true where the
// predicate is true and from on_false otherwise.
TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
  ComputationBuilder builder(client_, TestName());
  auto pred = builder.ConstantR1<bool>({false, true, false, true, false});
  auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
                             error_spec_);
}
87
// Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector
// is not a constant, but rather the result of an Eq comparison of two other
// (zero-element) S32 vectors. The expected result is an empty R1.
XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
  ComputationBuilder builder(client_, TestName());
  auto v1 = builder.ConstantR1<int32>({});
  auto v2 = builder.ConstantR1<int32>({});
  auto cmp = builder.Eq(v1, v2);
  auto on_true = builder.ConstantR1<float>({});
  auto on_false = builder.ConstantR1<float>({});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
101
// Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is
// not a constant, but rather the result of an Eq comparison of two S32
// vectors. v1 and v2 agree at indices 1 and 3, so those elements come from
// on_true and the rest from on_false.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
  ComputationBuilder builder(client_, TestName());
  auto v1 = builder.ConstantR1<int32>({1, 2, 3, 4, 5});
  auto v2 = builder.ConstantR1<int32>({9, 2, 9, 4, 9});
  auto cmp = builder.Eq(v1, v2);
  auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
                             error_spec_);
}
116
// Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s.
// v1 > v2 holds at indices 0, 1 and 4, so those elements come from on_true
// and indices 2 and 3 come from on_false.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
  ComputationBuilder builder(client_, TestName());
  auto v1 = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
  auto v2 = builder.ConstantR1<float>({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
  auto cmp = builder.Gt(v1, v2);
  auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
                             error_spec_);
}
130
// Selects among two R1F32s, which come from parameters. v1 and v2 are
// compared, and selection between them happens based on a gt-comparison mask,
// so the expected output is the element-wise larger of the two inputs.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle v1, v2;
  std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
      {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
      /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
      {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
      /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = builder.Gt(v1, v2);
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
150
// Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
// data size passed in and out is large.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
  ComputationBuilder builder(client_, TestName());

  // Number of floats in the data passed into and out of the computation.
  constexpr int datalen = 15 * 1000;

  // The inputs are initialized with a special pattern: in the first third of
  // the data v1[i] holds the larger value (2 * i) and v2[i] the smaller one
  // (i); elsewhere it's vice versa. At i == 0 both values are 0. The expected
  // output is always the larger value.
  std::vector<float> v1vec;
  std::vector<float> v2vec;
  std::vector<float> expected_vec;
  // The final size is known up front, so allocate once per vector.
  v1vec.reserve(datalen);
  v2vec.reserve(datalen);
  expected_vec.reserve(datalen);
  for (int i = 0; i < datalen; ++i) {
    const float smaller = static_cast<float>(i);
    const float larger = static_cast<float>(i * 2);
    if (i < datalen / 3) {
      v1vec.push_back(larger);
      v2vec.push_back(smaller);
    } else {
      v1vec.push_back(smaller);
      v2vec.push_back(larger);
    }
    expected_vec.push_back(larger);
  }

  ComputationDataHandle v1, v2;
  std::unique_ptr<GlobalData> param0_data =
      CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
                               /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data =
      CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
                               /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = builder.Gt(v1, v2);
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, expected_vec,
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
191
// "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to
// select between two R1F32s. Elements of v greater than 0 (indices 0 and 2)
// take on_true; the others take on_false.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<int32>({1, -1, 2, -2});
  auto s = builder.ConstantR0<int32>(0);
  auto cmp = builder.Gt(v, s);

  auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
  auto on_false =
      builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
                             error_spec_);
}
208
// "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to
// select between two R1F32s. Elements of v greater than 2.5 (indices 2 and 3)
// take on_true; the others take on_false.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
  ComputationBuilder builder(client_, TestName());
  auto v = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
  auto s = builder.ConstantR0<float>(2.5f);
  auto cmp = builder.Gt(v, s);

  auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
  auto on_false =
      builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
                             error_spec_);
}
225
// Select between two zero-element R1F32 constants with a scalar predicate.
// Runs once per predicate value; the expected result is empty in both cases.
XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
  for (bool which : {false, true}) {
    // A fresh builder per iteration keeps the two computations independent.
    ComputationBuilder builder(client_, TestName());
    auto pred = builder.ConstantR0<bool>(which);
    auto on_true = builder.ConstantR1<float>({});
    auto on_false = builder.ConstantR1<float>({});
    // The returned handle is not needed; the builder is what gets compared.
    builder.Select(pred, on_true, on_false);

    ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
  }
}
237
// Select between two R1F32 constants with a scalar `true` predicate: the
// whole on_true vector is the expected result.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
  ComputationBuilder builder(client_, TestName());
  auto pred = builder.ConstantR0<bool>(true);
  auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
  auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
}
247
// Select between two R1F32 constants with a scalar `false` predicate: the
// whole on_false vector is the expected result.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
  ComputationBuilder builder(client_, TestName());
  auto pred = builder.ConstantR0<bool>(false);
  auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
  auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
  // The returned handle is not needed; the builder is what gets compared.
  builder.Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
}
257 } // namespace
258 } // namespace xla
259