• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "tensorflow/compiler/xla/array2d.h"
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/tests/test_utils.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace xla {
38 namespace {
39 
40 class MapTest : public ClientLibraryTestBase {
41  public:
MapTest(se::Platform * platform=nullptr)42   explicit MapTest(se::Platform* platform = nullptr)
43       : ClientLibraryTestBase(platform) {
44     mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
45     mutable_debug_options()->add_xla_disable_hlo_passes("inline");
46   }
47 
48   // Creates a function that adds its scalar argument with the constant 1.0.
49   //
50   // x {R0F32} ----> (add)
51   //                /
52   // 1.0f ---------/
CreateAdderToOne()53   XlaComputation CreateAdderToOne() {
54     XlaBuilder mapped_builder(TestName());
55     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
56     auto one = ConstantR0<float>(&mapped_builder, 1.0);
57     Add(x, one);
58     auto computation_status = mapped_builder.Build();
59     TF_CHECK_OK(computation_status.status());
60     return computation_status.ConsumeValueOrDie();
61   }
62 
CreateMax()63   XlaComputation CreateMax() {
64     XlaBuilder b(TestName());
65     auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
66     auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
67     Max(lhs, rhs);
68     auto computation_status = b.Build();
69     TF_CHECK_OK(computation_status.status());
70     return computation_status.ConsumeValueOrDie();
71   }
72 
73   // Creates a computation that accepts an F32 and returns T(1) (ignoring the
74   // argument).
75   template <class T>
CreateScalarOne()76   XlaComputation CreateScalarOne() {
77     XlaBuilder mapped_builder("scalar_one");
78     (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
79     ConstantR0<T>(&mapped_builder, 1);
80     auto computation_status = mapped_builder.Build();
81     TF_CHECK_OK(computation_status.status());
82     return computation_status.ConsumeValueOrDie();
83   }
84 
85   // Creates a function that multiplies its scalar argument by the constant 2.0
86   //
87   // x {R0F32} ----> (mul)
88   //                /
89   // 2.0f ---------/
CreateMulByTwo()90   XlaComputation CreateMulByTwo() {
91     XlaBuilder mapped_builder(TestName());
92     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
93     auto two = ConstantR0<float>(&mapped_builder, 2.0);
94     Mul(x, two);
95     auto computation_status = mapped_builder.Build();
96     TF_CHECK_OK(computation_status.status());
97     return computation_status.ConsumeValueOrDie();
98   }
99 
100   // Creates a function that adds its scalar argument with the constant 1.0 and
101   // then multiplies by the original element.
102   //
103   //           /------------------|
104   //          /                   |
105   // x {R0F32} ----> (add) ----> (mul)
106   //                /
107   // 1.0f ---------/
CreateAdderToOneTimesItself()108   XlaComputation CreateAdderToOneTimesItself() {
109     XlaBuilder mapped_builder(TestName());
110     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
111     auto one = ConstantR0<float>(&mapped_builder, 1.0);
112     auto adder_to_one = Add(x, one);
113     Mul(x, adder_to_one);
114     auto computation_status = mapped_builder.Build();
115     TF_CHECK_OK(computation_status.status());
116     return computation_status.ConsumeValueOrDie();
117   }
118 
119   // Creates a function that takes a single parameter and calls map with
120   // "embedded_computation" on it, and then adds "n" to the result.
121   //
122   // x {R0F32} -----------> (map) ----> (add)
123   //                         /           /
124   // embedded_computation --/       n --/
CreateMapPlusN(const XlaComputation & embedded_computation,float n)125   XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation,
126                                 float n) {
127     XlaBuilder builder(TestName());
128     auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
129     auto map = Map(&builder, {x}, embedded_computation, {});
130     auto constant_n = ConstantR0<float>(&builder, n);
131     Add(map, constant_n);
132     auto computation_status = builder.Build();
133     TF_CHECK_OK(computation_status.status());
134     return computation_status.ConsumeValueOrDie();
135   }
136 
137   // Creates a binary function with signature (F32, F32) -> Pred
138   // defined by (x, y) -> x > y.
CreateGt()139   XlaComputation CreateGt() {
140     XlaBuilder b("Gt");
141     auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
142     auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
143     Gt(x, y);
144     auto computation_status = b.Build();
145     TF_CHECK_OK(computation_status.status());
146     return computation_status.ConsumeValueOrDie();
147   }
148 
149   // Creates a function that adds three scalar arguments
150   //
151   // x {R0F32} -------|
152   //                  |
153   // y {R0F32} ----> (add) ---> (add)
154   //                           /
155   // z {R0F32} ---------------/
CreateTernaryAdder()156   XlaComputation CreateTernaryAdder() {
157     XlaBuilder mapped_builder("TernaryAdder");
158     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
159     auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
160     auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z");
161     auto xy = Add(x, y);
162     Add(xy, z);
163     auto computation_status = mapped_builder.Build();
164     TF_CHECK_OK(computation_status.status());
165     return computation_status.ConsumeValueOrDie();
166   }
167 };
168 
TEST_F(MapTest, MapEachElemPlusOneR0) {
  // Maps (lambda (x) (+ x 1)) over a single R0F32 scalar: 42 becomes 43.
  XlaBuilder builder(TestName());
  Literal scalar_input = LiteralUtil::CreateR0<float>(42.0);

  auto input_param = Parameter(&builder, 0, scalar_input.shape(), "param0");
  // A rank-0 operand has no dimensions to map over.
  Map(&builder, {input_param}, CreateAdderToOne(), {});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(scalar_input).ConsumeValueOrDie();
  ComputeAndCompareR0<float>(&builder, 43.0, {device_input.get()},
                             ErrorSpec(0.01f));
}
182 
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
  // Edge case: maps (lambda (x) (+ x 1)) over an empty R1F32 vector; the
  // result must be an empty vector as well.
  XlaBuilder builder(TestName());
  Literal empty_vector = LiteralUtil::CreateR1<float>({});

  auto input_param = Parameter(&builder, 0, empty_vector.shape(), "param0");
  Map(&builder, {input_param}, CreateAdderToOne(), /*dimensions=*/{0});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(empty_vector).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {}, {device_input.get()},
                             ErrorSpec(0.01f));
}
196 
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
  // Maps (lambda (x) (+ x 1)) over a 4-element R1F32 vector.
  XlaBuilder builder(TestName());
  Literal input_vector =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});

  auto input_param = Parameter(&builder, 0, input_vector.shape(), "param0");
  Map(&builder, {input_param}, CreateAdderToOne(), /*dimensions=*/{0});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(input_vector).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
                             {device_input.get()}, ErrorSpec(0.01f));
}
211 
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
  // The mapped computation ignores its F32 argument and returns int32 1, so
  // the map's output element type (S32) differs from its input type (F32).
  XlaBuilder builder(TestName());
  Literal f32_vector = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});

  auto input_param = Parameter(&builder, 0, f32_vector.shape(), "param0");
  Map(&builder, {input_param}, CreateScalarOne<int32>(), /*dimensions=*/{0});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(f32_vector).ConsumeValueOrDie();
  ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {device_input.get()});
}
224 
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
  // The mapped computation ignores its F32 argument and returns uint32 1, so
  // the result is an R1 U32 vector of ones.
  XlaBuilder builder(TestName());
  Literal f32_vector = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});

  auto input_param = Parameter(&builder, 0, f32_vector.shape(), "param0");
  Map(&builder, {input_param}, CreateScalarOne<uint32>(), /*dimensions=*/{0});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(f32_vector).ConsumeValueOrDie();
  ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {device_input.get()});
}
237 
TEST_F(MapTest, MapEachElemLongerChainR1) {
  // Maps (lambda (x) (* (+ x 1) x)) over an R1F32 vector, e.g. 2.6 -> 9.36.
  XlaBuilder builder(TestName());
  Literal input_vector =
      LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});

  auto input_param = Parameter(&builder, 0, input_vector.shape(), "param0");
  Map(&builder, {input_param}, CreateAdderToOneTimesItself(),
      /*dimensions=*/{0});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(input_vector).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(
      &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
      {device_input.get()}, ErrorSpec(0.01f));
}
253 
XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
  // Edge case: chains two maps over an empty R1F32 vector -- first
  // (lambda (x) (+ x 1)), then (lambda (x) (* x 2)) on that output.
  XlaBuilder builder(TestName());
  Literal empty_vector = LiteralUtil::CreateR1<float>({});

  auto input_param = Parameter(&builder, 0, empty_vector.shape(), "param0");
  auto plus_one = Map(&builder, {input_param}, CreateAdderToOne(), {0});
  Map(&builder, {plus_one}, CreateMulByTwo(), {0});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(empty_vector).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {}, {device_input.get()},
                             ErrorSpec(0.01f));
}
269 
TEST_F(MapTest, MapMultipleMapsR1S4) {
  // Chains two maps over a 4-element R1F32 vector: first (lambda (x) (+ x 1)),
  // then (lambda (x) (* x 2)), so each element becomes (x + 1) * 2.
  XlaBuilder builder(TestName());
  Literal input_vector =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});

  auto input_param = Parameter(&builder, 0, input_vector.shape(), "param0");
  auto plus_one = Map(&builder, {input_param}, CreateAdderToOne(), {0});
  Map(&builder, {plus_one}, CreateMulByTwo(), {0});

  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(input_vector).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
                             {device_input.get()}, ErrorSpec(0.01f));
}
286 
TEST_F(MapTest, MapEachElemPlusOneR2) {
  // Maps (lambda (x) (+ x 1)) over every element of a 3x2 R2F32 array.
  XlaBuilder builder(TestName());
  Literal matrix = LiteralUtil::CreateR2<float>(
      {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
  std::unique_ptr<GlobalData> matrix_data =
      client_->TransferToServer(matrix).ConsumeValueOrDie();

  auto input_param = Parameter(&builder, 0, matrix.shape(), "param0");
  // Map over dimensions {0, 1}, i.e. all of the array's dimensions.
  Map(&builder, {input_param}, CreateAdderToOne(), /*dimensions=*/{0, 1});

  const Array2D<float> expected(
      {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
  ComputeAndCompareR2<float>(&builder, expected, {matrix_data.get()},
                             ErrorSpec(0.01f));
}
303 
XLA_TEST_F(MapTest, ComplexNestedMaps) {
  // Builds a graph of nested embedded computations to exercise the order in
  // which computations are lowered. Python equivalent:
  //
  //   embed1 = lambda x: x + 1                  #  x + 1
  //   embed2 = lambda x: embed1(x) + 2          #  x + 3
  //   embed3 = lambda x: embed1(x) + 4          #  x + 5
  //   embed4 = lambda x: embed2(x) + embed3(x)  # 2x + 8
  //   embed5 = lambda x: embed2(x) + 6          #  x + 9
  //   result = embed5(42) + embed4(7)           # (42 + 9) + (2 * 7 + 8) = 73
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  XlaComputation embed1 = CreateAdderToOne();
  XlaComputation embed2 = CreateMapPlusN(embed1, 2.0);
  XlaComputation embed3 = CreateMapPlusN(embed1, 4.0);

  // embed4 combines two different embedded computations, so it is assembled
  // by hand instead of through CreateMapPlusN.
  XlaBuilder embed4_builder("embed4");
  auto x = Parameter(&embed4_builder, 0, scalar_shape, "x");
  auto via_embed2 = Map(&embed4_builder, {x}, embed2, {});
  auto via_embed3 = Map(&embed4_builder, {x}, embed3, {});
  Add(via_embed2, via_embed3);
  StatusOr<XlaComputation> embed4_or = embed4_builder.Build();
  ASSERT_IS_OK(embed4_or.status());
  XlaComputation embed4 = embed4_or.ConsumeValueOrDie();

  XlaComputation embed5 = CreateMapPlusN(embed2, 6.0);

  XlaBuilder builder(TestName());
  auto forty_two = ConstantR0<float>(&builder, 42.0);
  auto seven = ConstantR0<float>(&builder, 7.0);
  auto embed5_of_42 = Map(&builder, {forty_two}, embed5, {});
  auto embed4_of_7 = Map(&builder, {seven}, embed4, {});
  Add(embed5_of_42, embed4_of_7);

  ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
}
341 
TEST_F(MapTest, MapBinaryAdder) {
  // Maps (lambda (x y) (+ x y)) element-wise over two R1F32 vectors.
  XlaBuilder builder(TestName());
  Literal param0_literal =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  Literal param1_literal =
      LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
      {0});

  // All expected values are spelled as float literals so that none of them
  // goes through an implicit double/int -> float conversion.
  ComputeAndCompareR1<float>(&builder, {7.3f, 7.7f, 4.3f, 0.0f},
                             {param0_data.get(), param1_data.get()},
                             ErrorSpec(0.01f));
}
363 
364 // Adds two rank-2 arrays with different layouts. This test exercises a path
365 // for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest,AddWithMixedLayouts)366 XLA_TEST_F(MapTest, AddWithMixedLayouts) {
367   XlaBuilder builder(TestName());
368   Literal param0_literal = LiteralUtil::CreateR2WithLayout(
369       {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
370   std::unique_ptr<GlobalData> param0_data =
371       client_->TransferToServer(param0_literal).ConsumeValueOrDie();
372 
373   Literal param1_literal = LiteralUtil::CreateR2WithLayout(
374       {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
375   std::unique_ptr<GlobalData> param1_data =
376       client_->TransferToServer(param1_literal).ConsumeValueOrDie();
377 
378   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
379   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
380   Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
381       {0, 1});
382 
383   Array2D<int32> expected(2, 2);
384   expected(0, 0) = 11;
385   expected(0, 1) = 22;
386   expected(1, 0) = 33;
387   expected(1, 1) = 44;
388   ComputeAndCompareR2<int32>(&builder, expected,
389                              {param0_data.get(), param1_data.get()});
390 }
391 
XLA_TEST_F(MapTest, AddR3_3x0x2) {
  // Edge case: element-wise add of two 3x0x2 S32 arrays. They contain no
  // elements, so the result is an equally empty 3x0x2 array.
  const Array3D<int32> empty_3x0x2(3, 0, 2);

  XlaBuilder builder(TestName());
  Literal lhs_literal = LiteralUtil::CreateR3FromArray3D<int32>(empty_3x0x2);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
  Literal rhs_literal = LiteralUtil::CreateR3FromArray3D<int32>(empty_3x0x2);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Map(&builder, {lhs, rhs}, CreateScalarAddComputation(S32, &builder),
      /*dimensions=*/{0, 1, 2});

  ComputeAndCompareR3<int32>(&builder, empty_3x0x2,
                             {lhs_data.get(), rhs_data.get()});
}
412 
TEST_F(MapTest, MapTernaryAdder) {
  // Maps (lambda (x y z) (+ x y z)) element-wise over three R1F32 vectors.
  XlaBuilder builder(TestName());
  Literal x_literal = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> x_data =
      client_->TransferToServer(x_literal).ConsumeValueOrDie();
  Literal y_literal = LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> y_data =
      client_->TransferToServer(y_literal).ConsumeValueOrDie();
  Literal z_literal =
      LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
  std::unique_ptr<GlobalData> z_data =
      client_->TransferToServer(z_literal).ConsumeValueOrDie();

  auto x = Parameter(&builder, 0, x_literal.shape(), "param0");
  auto y = Parameter(&builder, 1, y_literal.shape(), "param1");
  auto z = Parameter(&builder, 2, z_literal.shape(), "param2");
  Map(&builder, {x, y, z}, CreateTernaryAdder(), /*dimensions=*/{0});

  ComputeAndCompareR1<float>(&builder, {-2.7f, -92.3f, -895.7f, -400.0f},
                             {x_data.get(), y_data.get(), z_data.get()},
                             ErrorSpec(0.01f));
}
439 
TEST_F(MapTest, MapGt) {
  // Maps (x, y) -> x > y element-wise over two R1F32 constants, producing an
  // R1 PRED result.
  XlaBuilder b(TestName());
  XlaComputation greater_than = CreateGt();
  auto lhs = ConstantR1<float>(&b, {1, 20});
  auto rhs = ConstantR1<float>(&b, {10, 2});
  Map(&b, {lhs, rhs}, greater_than, /*dimensions=*/{0});
  ComputeAndCompareR1<bool>(&b, {false, true}, {});
}
448 
TEST_F(MapTest, NestedBinaryMap) {
  // max_with_square(x) = max(x, x * x), itself implemented as a binary map of
  // CreateMax() over (x, x * x). It is then mapped over an R1F32 vector.
  XlaComputation max_with_square;
  {
    XlaBuilder inner("max_with_square");
    auto x = Parameter(&inner, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto x_squared = Mul(x, x);
    Map(&inner, {x, x_squared}, CreateMax(), {});
    StatusOr<XlaComputation> built = inner.Build();
    ASSERT_IS_OK(built.status());
    max_with_square = built.ConsumeValueOrDie();
  }

  XlaBuilder b(TestName());
  auto input = ConstantR1<float>(&b, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
  Map(&b, {input}, max_with_square, {0});
  // Only -0.5 and 2.0 have squares larger than themselves.
  ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
}
465 
TEST_F(MapTest, MapOperantionWithBuildError) {
  // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors, but the mapped
  // computation adds an F32 to a U16, an unsupported type combination. Checks
  // that the sub-builder's error is reported by the outermost XlaBuilder.
  XlaBuilder builder(TestName());

  auto sub_builder = builder.CreateSubBuilder("ErrorAdd");
  auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y");
  Add(x, y);
  // Presumably notes the failure on `builder` rather than failing here; the
  // assertions below check that it surfaces from builder.Build().
  auto error_add = sub_builder->BuildAndNoteError();

  Literal param0_literal =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  Literal param1_literal =
      LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Map(&builder, {param0, param1}, error_add, {0});

  StatusOr<XlaComputation> computation_status = builder.Build();
  ASSERT_FALSE(computation_status.ok());
  EXPECT_THAT(computation_status.status().ToString(),
              ::testing::HasSubstr("error from: ErrorAdd: Binary op add with "
                                   "different element types: f32[] and u16[]"));
}
497 
498 // MapTest disables inline and algsimp. MapTestWithFullOpt runs all
499 // optimizations.
500 using MapTestWithFullOpt = ClientLibraryTestBase;
501 
502 // Regression test for b/31466798. The inliner simplifies map(param0, param1,
503 // power) to power(param0, param1) without deleting the old subcomputation which
504 // is the same as the new entry computation. HloSubcomputationUnification used
505 // to have issues with such patterns and maybe invalidate the pointer to entry
506 // computation.
TEST_F(MapTestWithFullOpt,MapScalarPower)507 TEST_F(MapTestWithFullOpt, MapScalarPower) {
508   XlaBuilder builder(TestName());
509 
510   auto sub_builder = builder.CreateSubBuilder("power");
511   auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
512   auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
513   Pow(x, y);
514   auto power = sub_builder->BuildAndNoteError();
515 
516   Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
517   Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
518   std::unique_ptr<GlobalData> param0_data =
519       client_->TransferToServer(param0_literal).ConsumeValueOrDie();
520   std::unique_ptr<GlobalData> param1_data =
521       client_->TransferToServer(param1_literal).ConsumeValueOrDie();
522 
523   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
524   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
525   Map(&builder, {param0, param1}, power, {});
526 
527   ComputeAndCompareR0<float>(&builder, 32.0f,
528                              {param0_data.get(), param1_data.get()},
529                              ErrorSpec(0.01f));
530 }
531 
532 // Regression test for b/35786417, where the inliner would not notice the change
533 // of parameter order inside the map.
TEST_F(MapTestWithFullOpt,MapSubtractOppositeOrder)534 TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
535   XlaBuilder builder(TestName());
536 
537   auto sub_builder = builder.CreateSubBuilder("power");
538   auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
539   auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
540   Sub(y, x);  // note that this is y - x, not x - y
541   auto sub_opposite = sub_builder->BuildAndNoteError();
542 
543   Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
544   Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
545   std::unique_ptr<GlobalData> param0_data =
546       client_->TransferToServer(param0_literal).ConsumeValueOrDie();
547   std::unique_ptr<GlobalData> param1_data =
548       client_->TransferToServer(param1_literal).ConsumeValueOrDie();
549 
550   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
551   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
552   Map(&builder, {param0, param1}, sub_opposite, {});
553 
554   ComputeAndCompareR0<float>(
555       &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
556 }
557 
558 // Regression test for b/35786417, where the inliner would CHECK-fail due to the
559 // mul inside the map having more parameters than the map does.
TEST_F(MapTestWithFullOpt,MapSquare)560 TEST_F(MapTestWithFullOpt, MapSquare) {
561   XlaBuilder builder(TestName());
562 
563   auto sub_builder = builder.CreateSubBuilder("power");
564   auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
565   Mul(x, x);
566   auto square = sub_builder->BuildAndNoteError();
567 
568   Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
569   std::unique_ptr<GlobalData> param0_data =
570       client_->TransferToServer(param0_literal).ConsumeValueOrDie();
571 
572   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
573   Map(&builder, {param0}, square, {});
574 
575   ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
576                              ErrorSpec(0.01f));
577 }
578 
579 }  // namespace
580 }  // namespace xla
581