1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "tensorflow/compiler/xla/array2d.h"
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/tests/test_utils.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
35 #include "tensorflow/core/platform/types.h"
36
37 namespace xla {
38 namespace {
39
40 class MapTest : public ClientLibraryTestBase {
41 public:
MapTest(se::Platform * platform=nullptr)42 explicit MapTest(se::Platform* platform = nullptr)
43 : ClientLibraryTestBase(platform) {
44 mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
45 mutable_debug_options()->add_xla_disable_hlo_passes("inline");
46 }
47
48 // Creates a function that adds its scalar argument with the constant 1.0.
49 //
50 // x {R0F32} ----> (add)
51 // /
52 // 1.0f ---------/
CreateAdderToOne()53 XlaComputation CreateAdderToOne() {
54 XlaBuilder mapped_builder(TestName());
55 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
56 auto one = ConstantR0<float>(&mapped_builder, 1.0);
57 Add(x, one);
58 auto computation_status = mapped_builder.Build();
59 TF_CHECK_OK(computation_status.status());
60 return computation_status.ConsumeValueOrDie();
61 }
62
CreateMax()63 XlaComputation CreateMax() {
64 XlaBuilder b(TestName());
65 auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
66 auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
67 Max(lhs, rhs);
68 auto computation_status = b.Build();
69 TF_CHECK_OK(computation_status.status());
70 return computation_status.ConsumeValueOrDie();
71 }
72
73 // Creates a computation that accepts an F32 and returns T(1) (ignoring the
74 // argument).
75 template <class T>
CreateScalarOne()76 XlaComputation CreateScalarOne() {
77 XlaBuilder mapped_builder("scalar_one");
78 (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
79 ConstantR0<T>(&mapped_builder, 1);
80 auto computation_status = mapped_builder.Build();
81 TF_CHECK_OK(computation_status.status());
82 return computation_status.ConsumeValueOrDie();
83 }
84
85 // Creates a function that multiplies its scalar argument by the constant 2.0
86 //
87 // x {R0F32} ----> (mul)
88 // /
89 // 2.0f ---------/
CreateMulByTwo()90 XlaComputation CreateMulByTwo() {
91 XlaBuilder mapped_builder(TestName());
92 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
93 auto two = ConstantR0<float>(&mapped_builder, 2.0);
94 Mul(x, two);
95 auto computation_status = mapped_builder.Build();
96 TF_CHECK_OK(computation_status.status());
97 return computation_status.ConsumeValueOrDie();
98 }
99
100 // Creates a function that adds its scalar argument with the constant 1.0 and
101 // then multiplies by the original element.
102 //
103 // /------------------|
104 // / |
105 // x {R0F32} ----> (add) ----> (mul)
106 // /
107 // 1.0f ---------/
CreateAdderToOneTimesItself()108 XlaComputation CreateAdderToOneTimesItself() {
109 XlaBuilder mapped_builder(TestName());
110 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
111 auto one = ConstantR0<float>(&mapped_builder, 1.0);
112 auto adder_to_one = Add(x, one);
113 Mul(x, adder_to_one);
114 auto computation_status = mapped_builder.Build();
115 TF_CHECK_OK(computation_status.status());
116 return computation_status.ConsumeValueOrDie();
117 }
118
119 // Creates a function that takes a single parameter and calls map with
120 // "embedded_computation" on it, and then adds "n" to the result.
121 //
122 // x {R0F32} -----------> (map) ----> (add)
123 // / /
124 // embedded_computation --/ n --/
CreateMapPlusN(const XlaComputation & embedded_computation,float n)125 XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation,
126 float n) {
127 XlaBuilder builder(TestName());
128 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
129 auto map = Map(&builder, {x}, embedded_computation, {});
130 auto constant_n = ConstantR0<float>(&builder, n);
131 Add(map, constant_n);
132 auto computation_status = builder.Build();
133 TF_CHECK_OK(computation_status.status());
134 return computation_status.ConsumeValueOrDie();
135 }
136
137 // Creates a binary function with signature (F32, F32) -> Pred
138 // defined by (x, y) -> x > y.
CreateGt()139 XlaComputation CreateGt() {
140 XlaBuilder b("Gt");
141 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
142 auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
143 Gt(x, y);
144 auto computation_status = b.Build();
145 TF_CHECK_OK(computation_status.status());
146 return computation_status.ConsumeValueOrDie();
147 }
148
149 // Creates a function that adds three scalar arguments
150 //
151 // x {R0F32} -------|
152 // |
153 // y {R0F32} ----> (add) ---> (add)
154 // /
155 // z {R0F32} ---------------/
CreateTernaryAdder()156 XlaComputation CreateTernaryAdder() {
157 XlaBuilder mapped_builder("TernaryAdder");
158 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
159 auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
160 auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z");
161 auto xy = Add(x, y);
162 Add(xy, z);
163 auto computation_status = mapped_builder.Build();
164 TF_CHECK_OK(computation_status.status());
165 return computation_status.ConsumeValueOrDie();
166 }
167 };
168
TEST_F(MapTest, MapEachElemPlusOneR0) {
  // Applies (lambda (x) (+ x 1)) to a scalar input: 42 -> 43.
  Literal input = LiteralUtil::CreateR0<float>(42.0);
  auto input_data = client_->TransferToServer(input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {x}, CreateAdderToOne(), {});

  ComputeAndCompareR0<float>(&builder, 43.0, {input_data.get()},
                             ErrorSpec(0.01f));
}
182
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
  // Edge case: maps (lambda (x) (+ x 1)) over a zero-length R1F32 vector; the
  // result must be empty as well.
  Literal empty_input = LiteralUtil::CreateR1<float>({});
  auto input_data = client_->TransferToServer(empty_input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, empty_input.shape(), "param0");
  Map(&builder, {x}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {input_data.get()},
                             ErrorSpec(0.01f));
}
196
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
  // Maps (lambda (x) (+ x 1)) over a length-4 R1F32 vector.
  Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_data = client_->TransferToServer(input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {x}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
                             {input_data.get()}, ErrorSpec(0.01f));
}
211
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
  // The mapped computation ignores its F32 argument and yields the S32 value
  // 1, so the map's result element type differs from its operand's.
  Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_data = client_->TransferToServer(input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {x}, CreateScalarOne<int32>(), {0});

  ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {input_data.get()});
}
224
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
  // The mapped computation ignores its F32 argument and yields the U32 value
  // 1, so the map's result element type differs from its operand's.
  Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_data = client_->TransferToServer(input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {x}, CreateScalarOne<uint32>(), {0});

  ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {input_data.get()});
}
237
TEST_F(MapTest, MapEachElemLongerChainR1) {
  // Maps (lambda (x) (* (+ x 1) x)) over an R1F32 vector; the mapped
  // computation uses its parameter twice.
  Literal input =
      LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
  auto input_data = client_->TransferToServer(input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {x}, CreateAdderToOneTimesItself(), {0});

  ComputeAndCompareR1<float>(
      &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
      {input_data.get()}, ErrorSpec(0.01f));
}
253
XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
  // Chains two maps over a zero-length R1F32 vector: first
  // (lambda (x) (+ x 1)), then (lambda (x) (* x 2)) on that result.
  Literal empty_input = LiteralUtil::CreateR1<float>({});
  auto input_data = client_->TransferToServer(empty_input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, empty_input.shape(), "param0");
  auto plus_one = Map(&builder, {x}, CreateAdderToOne(), {0});
  Map(&builder, {plus_one}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {input_data.get()},
                             ErrorSpec(0.01f));
}
269
TEST_F(MapTest, MapMultipleMapsR1S4) {
  // Chains two maps over a length-4 R1F32 vector: first (lambda (x) (+ x 1)),
  // then (lambda (x) (* x 2)) on that result, i.e. 2 * (x + 1).
  Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_data = client_->TransferToServer(input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, input.shape(), "param0");
  auto plus_one = Map(&builder, {x}, CreateAdderToOne(), {0});
  Map(&builder, {plus_one}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
                             {input_data.get()}, ErrorSpec(0.01f));
}
286
TEST_F(MapTest, MapEachElemPlusOneR2) {
  // Maps (lambda (x) (+ x 1)) over a 3x2 R2F32 array, across both dimensions.
  Literal input = LiteralUtil::CreateR2<float>(
      {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
  auto input_data = client_->TransferToServer(input).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {x}, CreateAdderToOne(), {0, 1});

  Array2D<float> expected({{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
  ComputeAndCompareR2<float>(&builder, expected, {input_data.get()},
                             ErrorSpec(0.01f));
}
303
XLA_TEST_F(MapTest, ComplexNestedMaps) {
  // Builds a small DAG of embedded computations that reuse one another, to
  // exercise the order in which nested computations are lowered. Python
  // equivalent:
  //
  //   embed1 = lambda x: x + 1                  # x + 1
  //   embed2 = lambda x: embed1(x) + 2          # x + 3
  //   embed3 = lambda x: embed1(x) + 4          # x + 5
  //   embed4 = lambda x: embed2(x) + embed3(x)  # 2x + 8
  //   embed5 = lambda x: embed2(x) + 6          # x + 9
  //   result = embed5(42) + embed4(7)  # (42 + 9) + (2 * 7 + 8) = 73
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  const XlaComputation embed1 = CreateAdderToOne();
  const XlaComputation embed2 = CreateMapPlusN(embed1, 2.0);
  const XlaComputation embed3 = CreateMapPlusN(embed1, 4.0);

  // embed4 maps both embed2 and embed3 over the same parameter and sums them.
  XlaBuilder embed4_builder("embed4");
  auto param = Parameter(&embed4_builder, 0, scalar_shape, "x");
  auto via_embed2 = Map(&embed4_builder, {param}, embed2, {});
  auto via_embed3 = Map(&embed4_builder, {param}, embed3, {});
  Add(via_embed2, via_embed3);
  auto embed4_or = embed4_builder.Build();
  ASSERT_IS_OK(embed4_or.status());
  auto embed4 = embed4_or.ConsumeValueOrDie();

  const XlaComputation embed5 = CreateMapPlusN(embed2, 6.0);

  XlaBuilder builder(TestName());
  auto forty_two = ConstantR0<float>(&builder, 42.0);
  auto seven = ConstantR0<float>(&builder, 7.0);
  auto lhs = Map(&builder, {forty_two}, embed5, {});
  auto rhs = Map(&builder, {seven}, embed4, {});
  Add(lhs, rhs);

  ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
}
341
TEST_F(MapTest, MapBinaryAdder) {
  // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
  XlaBuilder builder(TestName());
  Literal param0_literal =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  Literal param1_literal =
      LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
      {0});

  // Expected values are all float literals so the initializer list matches
  // the F32 result type exactly, with no double/int narrowing.
  ComputeAndCompareR1<float>(&builder, {7.3f, 7.7f, 4.3f, 0.0f},
                             {param0_data.get(), param1_data.get()},
                             ErrorSpec(0.01f));
}
363
364 // Adds two rank-2 arrays with different layouts. This test exercises a path
365 // for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
  // The two S32 operands use different layouts ({1, 0} vs. {0, 1}); the
  // element-wise sum must not depend on layout.
  Literal lhs = LiteralUtil::CreateR2WithLayout(
      {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
  auto lhs_data = client_->TransferToServer(lhs).ConsumeValueOrDie();

  Literal rhs = LiteralUtil::CreateR2WithLayout(
      {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
  auto rhs_data = client_->TransferToServer(rhs).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto lhs_param = Parameter(&builder, 0, lhs.shape(), "param0");
  auto rhs_param = Parameter(&builder, 1, rhs.shape(), "param1");
  Map(&builder, {lhs_param, rhs_param},
      CreateScalarAddComputation(S32, &builder), {0, 1});

  Array2D<int32> expected({{11, 22}, {33, 44}});
  ComputeAndCompareR2<int32>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()});
}
391
XLA_TEST_F(MapTest, AddR3_3x0x2) {
  // Edge case: element-wise S32 addition over two 3x0x2 arrays. The zero-sized
  // middle dimension makes both inputs and the result empty.
  const Array3D<int32> empty(3, 0, 2);

  Literal lhs = LiteralUtil::CreateR3FromArray3D<int32>(empty);
  auto lhs_data = client_->TransferToServer(lhs).ConsumeValueOrDie();

  Literal rhs = LiteralUtil::CreateR3FromArray3D<int32>(empty);
  auto rhs_data = client_->TransferToServer(rhs).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto lhs_param = Parameter(&builder, 0, lhs.shape(), "param0");
  auto rhs_param = Parameter(&builder, 1, rhs.shape(), "param1");
  Map(&builder, {lhs_param, rhs_param},
      CreateScalarAddComputation(S32, &builder), {0, 1, 2});

  ComputeAndCompareR3<int32>(&builder, empty,
                             {lhs_data.get(), rhs_data.get()});
}
412
TEST_F(MapTest, MapTernaryAdder) {
  // Maps (lambda (x y z) (+ x y z)) over three R1F32 vectors, so the map takes
  // more than two operands.
  Literal x_literal = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
  Literal y_literal = LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
  Literal z_literal =
      LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
  auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, x_literal.shape(), "param0");
  auto y = Parameter(&builder, 1, y_literal.shape(), "param1");
  auto z = Parameter(&builder, 2, z_literal.shape(), "param2");
  Map(&builder, {x, y, z}, CreateTernaryAdder(), {0});

  ComputeAndCompareR1<float>(&builder, {-2.7f, -92.3f, -895.7f, -400.0f},
                             {x_data.get(), y_data.get(), z_data.get()},
                             ErrorSpec(0.01f));
}
439
TEST_F(MapTest, MapGt) {
  // Maps the predicate (x, y) -> x > y over two constant R1F32 vectors; the
  // result is an R1 of booleans.
  XlaBuilder builder(TestName());
  auto greater_than = CreateGt();
  auto lhs = ConstantR1<float>(&builder, {1, 20});
  auto rhs = ConstantR1<float>(&builder, {10, 2});
  Map(&builder, {lhs, rhs}, greater_than, {0});
  ComputeAndCompareR1<bool>(&builder, {false, true}, {});
}
448
TEST_F(MapTest, NestedBinaryMap) {
  // The mapped computation max_with_square(x) = max(x, x * x) is itself
  // written as a binary Map (using CreateMax) over x and x * x.
  XlaComputation max_with_square;
  {
    XlaBuilder inner("max_with_square");
    auto x = Parameter(&inner, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto x_squared = Mul(x, x);
    Map(&inner, {x, x_squared}, CreateMax(), {});
    auto built = inner.Build();
    ASSERT_IS_OK(built.status());
    max_with_square = built.ConsumeValueOrDie();
  }

  XlaBuilder builder(TestName());
  auto input = ConstantR1<float>(&builder, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
  Map(&builder, {input}, max_with_square, {0});
  ComputeAndCompareR1<float>(&builder, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
}
465
// NOTE(review): "Operantion" is a typo in the test name; it is kept so that
// existing --gtest_filter patterns keep matching.
TEST_F(MapTest, MapOperantionWithBuildError) {
  // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors, but the mapped
  // computation adds an F32 to a U16, which is an unsupported type
  // combination. The sub-builder's error must be reported by the outermost
  // XlaBuilder's Build().
  //
  // The computation is only built, never executed, so the parameters need
  // shapes only; no argument data is transferred to the device.
  XlaBuilder builder(TestName());

  auto sub_builder = builder.CreateSubBuilder("ErrorAdd");
  auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y");
  Add(x, y);
  auto error_add = sub_builder->BuildAndNoteError();

  const Shape vector_shape = ShapeUtil::MakeShape(F32, {4});
  auto param0 = Parameter(&builder, 0, vector_shape, "param0");
  auto param1 = Parameter(&builder, 1, vector_shape, "param1");
  Map(&builder, {param0, param1}, error_add, {0});

  StatusOr<XlaComputation> computation_status = builder.Build();
  ASSERT_FALSE(computation_status.ok());
  EXPECT_THAT(computation_status.status().ToString(),
              ::testing::HasSubstr("error from: ErrorAdd: Binary op add with "
                                   "different element types: f32[] and u16[]"));
}
497
// MapTest disables the "inline" and "algsimp" passes. MapTestWithFullOpt is
// simply the plain ClientLibraryTestBase fixture, so the tests below run with
// all optimizations, including the inliner that rewrites Map instructions.
using MapTestWithFullOpt = ClientLibraryTestBase;
501
// Regression test for b/31466798. The inliner simplifies map(param0, param1,
// power) to power(param0, param1) without deleting the old subcomputation,
// which is the same as the new entry computation. HloSubcomputationUnification
// used to mishandle such patterns and could invalidate the pointer to the
// entry computation.
TEST_F(MapTestWithFullOpt, MapScalarPower) {
  XlaBuilder builder(TestName());

  // Scalar computation (x, y) -> pow(x, y), built as a child of `builder`.
  auto power_builder = builder.CreateSubBuilder("power");
  auto scalar = ShapeUtil::MakeShape(F32, {});
  auto base = Parameter(power_builder.get(), 0, scalar, "x");
  auto exponent = Parameter(power_builder.get(), 1, scalar, "y");
  Pow(base, exponent);
  auto power = power_builder->BuildAndNoteError();

  Literal base_literal = LiteralUtil::CreateR0<float>(2.0f);
  Literal exponent_literal = LiteralUtil::CreateR0<float>(5.0f);
  auto base_data = client_->TransferToServer(base_literal).ConsumeValueOrDie();
  auto exponent_data =
      client_->TransferToServer(exponent_literal).ConsumeValueOrDie();

  auto base_param = Parameter(&builder, 0, base_literal.shape(), "param0");
  auto exponent_param =
      Parameter(&builder, 1, exponent_literal.shape(), "param1");
  Map(&builder, {base_param, exponent_param}, power, {});

  // 2^5 = 32.
  ComputeAndCompareR0<float>(&builder, 32.0f,
                             {base_data.get(), exponent_data.get()},
                             ErrorSpec(0.01f));
}
531
532 // Regression test for b/35786417, where the inliner would not notice the change
533 // of parameter order inside the map.
TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
  XlaBuilder builder(TestName());

  // The mapped computation deliberately computes y - x (parameter 1 minus
  // parameter 0). An inliner that ignored the parameter order would yield
  // x - y = -3 instead of the expected 5 - 2 = 3. The sub-builder is named for
  // what it computes; it was previously mislabeled "power" (copy-paste).
  auto sub_builder = builder.CreateSubBuilder("sub_opposite");
  auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
  Sub(y, x);  // note that this is y - x, not x - y
  auto sub_opposite = sub_builder->BuildAndNoteError();

  Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
  Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Map(&builder, {param0, param1}, sub_opposite, {});

  ComputeAndCompareR0<float>(
      &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
}
557
558 // Regression test for b/35786417, where the inliner would CHECK-fail due to the
559 // mul inside the map having more parameters than the map does.
TEST_F(MapTestWithFullOpt, MapSquare) {
  XlaBuilder builder(TestName());

  // Unary computation x -> x * x. Its single parameter is used twice by the
  // mul, which is the pattern that used to trip the inliner. The sub-builder
  // is named for what it computes; it was previously mislabeled "power".
  auto sub_builder = builder.CreateSubBuilder("square");
  auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  Mul(x, x);
  auto square = sub_builder->BuildAndNoteError();

  Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  Map(&builder, {param0}, square, {});

  ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
                             ErrorSpec(0.01f));
}
578
579 } // namespace
580 } // namespace xla
581