• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/client/computation_builder.h"
20 #include "tensorflow/compiler/xla/client/global_data.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace xla {
30 namespace {
31 
32 class SelectTest : public ClientLibraryTestBase {
33  public:
34   ErrorSpec error_spec_{0.0001};
35 };
36 
// A constant-true scalar predicate must pick the on_true F32 scalar.
TEST_F(SelectTest, SelectScalarF32True) {
  ComputationBuilder builder(client_, TestName());
  const auto predicate = builder.ConstantR0<bool>(true);
  const auto if_true = builder.ConstantR0<float>(123.0f);
  const auto if_false = builder.ConstantR0<float>(42.0f);
  // The last op added to the builder is the computation's result.
  builder.Select(predicate, if_true, if_false);

  const float expected = 123.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}
46 
// A constant-true scalar predicate must pick the on_true S32 scalar.
TEST_F(SelectTest, SelectScalarS32True) {
  ComputationBuilder builder(client_, TestName());
  const auto predicate = builder.ConstantR0<bool>(true);
  const auto if_true = builder.ConstantR0<int32>(-42);
  const auto if_false = builder.ConstantR0<int32>(42);
  builder.Select(predicate, if_true, if_false);

  // Integer results are compared exactly, so no ErrorSpec is supplied.
  const int32 expected = -42;
  ComputeAndCompareR0<int32>(&builder, expected, {});
}
56 
// A constant-false scalar predicate must pick the on_false F32 scalar.
TEST_F(SelectTest, SelectScalarF32False) {
  ComputationBuilder builder(client_, TestName());
  const auto predicate = builder.ConstantR0<bool>(false);
  const auto if_true = builder.ConstantR0<float>(123.0f);
  const auto if_false = builder.ConstantR0<float>(42.0f);
  builder.Select(predicate, if_true, if_false);

  const float expected = 42.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}
66 
// Selecting between two empty F32 vectors under an empty constant predicate
// vector must yield an empty vector.
XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
  ComputationBuilder builder(client_, TestName());
  const auto predicate = builder.ConstantR1<bool>({});
  const auto if_true = builder.ConstantR1<float>({});
  const auto if_false = builder.ConstantR1<float>({});
  builder.Select(predicate, if_true, if_false);

  // Empty expected result, no runtime arguments.
  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
76 
// Element-wise select driven by a constant predicate vector: lanes where the
// predicate is true take on_true, the rest take on_false.
TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
  ComputationBuilder builder(client_, TestName());
  const auto predicate =
      builder.ConstantR1<bool>({false, true, false, true, false});
  const auto if_true =
      builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const auto if_false =
      builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  builder.Select(predicate, if_true, if_false);

  // Lanes 1 and 3 come from if_true; lanes 0, 2 and 4 from if_false.
  ComputeAndCompareR1<float>(&builder,
                             {10.0f, 25.5f, 1.0f, -10.0f, -6.0f},
                             {}, error_spec_);
}
87 
// Empty-vector select whose predicate is computed, not constant: it comes
// from an Eq comparison of two empty S32 vectors.
XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
  ComputationBuilder builder(client_, TestName());
  const auto lhs = builder.ConstantR1<int32>({});
  const auto rhs = builder.ConstantR1<int32>({});
  const auto mask = builder.Eq(lhs, rhs);
  const auto if_true = builder.ConstantR1<float>({});
  const auto if_false = builder.ConstantR1<float>({});
  builder.Select(mask, if_true, if_false);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
101 
// Element-wise select whose predicate vector is computed by Eq-comparing two
// S32 vectors. The resulting mask matches the constant predicate used in
// SelectR1F32WithConstantR1PRED.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
  ComputationBuilder builder(client_, TestName());
  const auto lhs = builder.ConstantR1<int32>({1, 2, 3, 4, 5});
  const auto rhs = builder.ConstantR1<int32>({9, 2, 9, 4, 9});
  // Equal only at lanes 1 and 3.
  const auto mask = builder.Eq(lhs, rhs);
  const auto if_true =
      builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const auto if_false =
      builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  builder.Select(mask, if_true, if_false);

  ComputeAndCompareR1<float>(&builder,
                             {10.0f, 25.5f, 1.0f, -10.0f, -6.0f},
                             {}, error_spec_);
}
116 
// Element-wise select whose predicate comes from a Gt comparison of two F32
// vectors.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
  ComputationBuilder builder(client_, TestName());
  const auto lhs = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
  const auto rhs =
      builder.ConstantR1<float>({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
  // lhs > rhs at lanes 0, 1 and 4.
  const auto mask = builder.Gt(lhs, rhs);
  const auto if_true =
      builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const auto if_false =
      builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  builder.Select(mask, if_true, if_false);

  ComputeAndCompareR1<float>(&builder,
                             {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f},
                             {}, error_spec_);
}
130 
// Element-wise max of two small F32 parameter vectors, written as
// Select(Gt(v1, v2), v1, v2).
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
  ComputationBuilder builder(client_, TestName());

  // Handles filled in by CreateR1Parameter for parameters 0 and 1.
  ComputationDataHandle lhs;
  ComputationDataHandle rhs;
  std::unique_ptr<GlobalData> lhs_data = CreateR1Parameter<float>(
      {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
      /*builder=*/&builder, /*data_handle=*/&lhs);
  std::unique_ptr<GlobalData> rhs_data = CreateR1Parameter<float>(
      {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
      /*builder=*/&builder, /*data_handle=*/&rhs);

  const auto mask = builder.Gt(lhs, rhs);
  builder.Select(mask, lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
150 
// Same as SelectR1F32WithCmpR1F32sFromParamsSmall, but with enough data
// (15000 floats per operand) to exercise larger transfers into and out of
// the computation.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
  ComputationBuilder builder(client_, TestName());

  // Number of floats in the data passed into and out of the computation.
  constexpr int datalen = 15 * 1000;

  // Input pattern:
  // - For 0 < i < datalen / 3, v1[i] > v2[i].
  // - At i == 0, both are 0, so Gt is false and v2 is picked. That value
  //   still equals the larger one.
  // - For i >= datalen / 3, v1[i] < v2[i].
  // In every case the select yields the larger value.
  std::vector<float> v1vec;
  std::vector<float> v2vec;
  std::vector<float> expected_vec;
  // Sizes are known up front; avoid repeated reallocation.
  v1vec.reserve(datalen);
  v2vec.reserve(datalen);
  expected_vec.reserve(datalen);
  for (int i = 0; i < datalen; ++i) {
    // All values stay far below 2^24, so the int -> float conversion is exact.
    const float smaller = static_cast<float>(i);
    const float larger = static_cast<float>(i * 2);
    if (i < datalen / 3) {
      v1vec.push_back(larger);
      v2vec.push_back(smaller);
    } else {
      v1vec.push_back(smaller);
      v2vec.push_back(larger);
    }
    expected_vec.push_back(larger);
  }

  ComputationDataHandle v1, v2;
  std::unique_ptr<GlobalData> param0_data =
      CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
                               /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data =
      CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
                               /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = builder.Gt(v1, v2);
  auto select = builder.Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, expected_vec,
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
191 
// Gt-compares an S32 vector against an S32 scalar (implicitly broadcast).
// The resulting predicate vector then selects between two F32 vectors.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
  ComputationBuilder builder(client_, TestName());
  const auto values = builder.ConstantR1<int32>({1, -1, 2, -2});
  const auto zero = builder.ConstantR0<int32>(0);
  // Positive at lanes 0 and 2.
  const auto mask = builder.Gt(values, zero);

  const auto if_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
  const auto if_false =
      builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
  builder.Select(mask, if_true, if_false);

  ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
                             error_spec_);
}
208 
// Gt-compares an F32 vector against an F32 scalar (implicitly broadcast).
// The resulting predicate vector then selects between two F32 vectors.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
  ComputationBuilder builder(client_, TestName());
  const auto values = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
  const auto threshold = builder.ConstantR0<float>(2.5f);
  // Above the threshold at lanes 2 and 3.
  const auto mask = builder.Gt(values, threshold);

  const auto if_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
  const auto if_false =
      builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
  builder.Select(mask, if_true, if_false);

  ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
                             error_spec_);
}
225 
// A scalar predicate selecting between two empty F32 vectors must yield an
// empty vector. Both predicate values are checked, false first, then true,
// each with a fresh builder.
XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
  for (int iteration = 0; iteration < 2; ++iteration) {
    const bool which = (iteration == 1);
    ComputationBuilder builder(client_, TestName());
    const auto predicate = builder.ConstantR0<bool>(which);
    const auto if_true = builder.ConstantR1<float>({});
    const auto if_false = builder.ConstantR1<float>({});
    builder.Select(predicate, if_true, if_false);

    ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
  }
}
237 
// A true scalar predicate selects the whole on_true vector.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
  ComputationBuilder builder(client_, TestName());
  const auto predicate = builder.ConstantR0<bool>(true);
  const auto if_true = builder.ConstantR1<float>({-2.5f, 25.5f});
  const auto if_false = builder.ConstantR1<float>({10.0f, 5.0f});
  builder.Select(predicate, if_true, if_false);

  ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
}
247 
// A false scalar predicate selects the whole on_false vector.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
  ComputationBuilder builder(client_, TestName());
  const auto predicate = builder.ConstantR0<bool>(false);
  const auto if_true = builder.ConstantR1<float>({-2.5f, 25.5f});
  const auto if_false = builder.ConstantR1<float>({10.0f, 5.0f});
  builder.Select(predicate, if_true, if_false);

  ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
}
257 }  // namespace
258 }  // namespace xla
259