1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal.h"
17
18 #include <vector>
19
20 #include "absl/base/casts.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace xla {
37 namespace {
38
39 using ::testing::ElementsAre;
40 using ::testing::HasSubstr;
41
42 class LiteralUtilTest : public ::testing::Test {
43 protected:
LiteralUtilTest()44 LiteralUtilTest() {
45 Array4D<float> arr4d({
46 // clang-format off
47 { // i0=0
48 { // i1=0
49 {1, 2, 3}, // i2=0
50 {4, 5, 6}, // i2=1
51 {7, 8, 9}, // i2=2
52 },
53 { // i1=1
54 {11, 12, 13},
55 {14, 15, 16},
56 {17, 18, 19},
57 },
58 },
59 { // i0=1
60 { // i1=0
61 {101, 102, 103},
62 {104, 105, 106},
63 {107, 108, 109},
64 },
65 { // i1=1
66 {201, 202, 203}, // i2=0
67 {204, 205, 206}, // i2=1
68 {207, 208, 209}, // i2=2
69 },
70 },
71 // clang-format on
72 });
73
74 layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
75 layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
76 layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
77 layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
78 layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
79 layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
80
81 literal_r4_2x2x3x3_dim0major_ =
82 LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
83 layout_r4_dim0major_);
84 literal_r4_2x2x3x3_dim0minor_ =
85 LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
86 layout_r4_dim0minor_);
87 }
88
89 Layout layout_r2_dim0major_;
90 Layout layout_r2_dim0minor_;
91 Layout layout_r3_dim0major_;
92 Layout layout_r3_dim0minor_;
93 Layout layout_r4_dim0major_;
94 Layout layout_r4_dim0minor_;
95 Literal literal_r4_2x2x3x3_dim0major_;
96 Literal literal_r4_2x2x3x3_dim0minor_;
97 };
98
// Checks the textual rendering of rank-0 literals for each primitive type,
// including bfloat16 values that are not exactly representable.
TEST_F(LiteralUtilTest, LiteralScalarToString) {
  auto true_lit = LiteralUtil::CreateR0<bool>(true);
  EXPECT_EQ("pred[] true", true_lit.ToString());

  auto false_lit = LiteralUtil::CreateR0<bool>(false);
  EXPECT_EQ("pred[] false", false_lit.ToString());

  auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
  EXPECT_EQ("u32[] 42", u32_lit.ToString());

  auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
  EXPECT_EQ("s32[] -999", s32_lit.ToString());

  auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
  EXPECT_EQ("f32[] 3.14", f32_lit.ToString());

  auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
  EXPECT_EQ("f16[] 0.5", f16_lit.ToString());

  auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
  EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());

  // complex128 holds doubles, so use double literals. With float literals the
  // widened value (3.14f != 3.14) only matches the expected text because the
  // printer happens to drop the extra digits.
  auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14, 2.78});
  EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());

  auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
  EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());

  // 3.14 will be rounded to 3.14062 in bfloat16 format.
  auto bf16_lit_truncated =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
  EXPECT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString());

  auto bf16_lit_truncated2 =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
  EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString());
}
136
// A rank-1 PRED literal prints its elements as 1/0 (the scalar case prints
// true/false).
TEST_F(LiteralUtilTest, LiteralVectorToString) {
  const auto bools = LiteralUtil::CreateR1<bool>({true, false, true});
  EXPECT_EQ("pred[3] {1, 0, 1}", bools.ToString());
}
141
// A 3x2 S32 matrix prints one row per line inside the outer braces.
TEST_F(LiteralUtilTest, R2ToString) {
  const string expected = R"(s32[3,2] {
  { 1, 2 },
  { 3, 4 },
  { 5, 6 }
})";
  const auto matrix = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
  EXPECT_EQ(expected, matrix.ToString());
}
151
// Rank-3 output has one brace group per outermost index. Rows of size 1 print
// without the inner padding used for longer rows.
TEST_F(LiteralUtilTest, R3ToString) {
  const string expected = R"(s32[3,2,1] {
{
  {1},
  {2}
},
{
  {3},
  {4}
},
{
  {5},
  {6}
}
})";
  const auto cube =
      LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
  EXPECT_EQ(expected, cube.ToString());
}
171
// For rank > 3, each non-innermost brace is annotated with its index
// (/*iK=V*/). The literal comes from CreateFromDimensions and is all zeros.
TEST_F(LiteralUtilTest, R6ToString) {
  const string expected = R"(s32[2,2,1,1,1,2] {
{ /*i0=0*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
}
},
{ /*i0=1*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
}
}
})";
  const auto zeros =
      LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2});
  EXPECT_EQ(expected, zeros.ToString());
}
211
// A tuple prints its elements one after another, comma separated, inside
// parentheses.
TEST_F(LiteralUtilTest, TupleToString) {
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const string expected = R"((
f32[] 1,
f32[2,2] {
  { 1, 2 },
  { 3, 4 }
}
))";
  EXPECT_EQ(expected, LiteralUtil::MakeTuple({&scalar, &matrix}).ToString());
}
225
// CreateR3FromArray3D keeps the Array3D's dimensions (2x3x2) and values.
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
  // clang-format off
  Array3D<float> source({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on

  auto literal = LiteralUtil::CreateR3FromArray3D(source);
  EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));

  const string expected = R"(f32[2,3,2] {
{
  { 1, 2 },
  { 3, 4 },
  { 5, 6 }
},
{
  { 7, 8 },
  { 9, 10 },
  { 11, 12 }
}
})";
  const string actual = literal.ToString();
  EXPECT_EQ(expected, actual);
}
255
// Builds a sparse S64 literal from unsorted indices. Checks that the stored
// indices come back in lexicographic order with their values permuted to
// match, both directly and after a proto round trip.
TEST_F(LiteralUtilTest, CreateSparse) {
  std::vector<int64> dimensions = {8, 8, 8};
  Array2D<int64> indices = {
      {3, 4, 5},
      {1, 2, 3},
      {2, 3, 4},
      {3, 5, 6},
  };
  std::vector<int64> values = {7, 8, 9, 10};
  auto literal = LiteralUtil::CreateSparse<int64>(
      dimensions, SparseIndexArray(indices.n1() + 3, indices), values);

  Array2D<int64> sorted_indices = {
      {1, 2, 3},
      {2, 3, 4},
      {3, 4, 5},
      {3, 5, 6},
  };
  std::vector<int64> sorted_values = {8, 9, 7, 10};

  // Shared expectations for the original and the deserialized literal.
  auto expect_sorted = [&](Literal& lit) {
    EXPECT_EQ(lit.sparse_indices()->data(),
              absl::Span<const int64>(sorted_indices.data(),
                                      sorted_indices.num_elements()));
    EXPECT_EQ(lit.data<int64>(), absl::Span<const int64>(sorted_values));
  };

  expect_sorted(literal);

  // Serialize then deserialize and verify the resulting literal.
  TF_ASSERT_OK_AND_ASSIGN(Literal round_tripped,
                          Literal::CreateFromProto(literal.ToProto()));
  expect_sorted(round_tripped);
}
291
// CreateR4Projected replicates a 2D array into a 1x2x3x2 literal (p=1, z=2).
// Checks both its shape and its string form.
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
  // clang-format off
  auto projected = LiteralUtil::CreateR4Projected<float>({
    {1, 2},
    {1001, 1002},
    {2001, 2002},
  }, /*projection_p=*/1, /*projection_z=*/2);
  // clang-format on
  EXPECT_THAT(projected.shape().dimensions(), ElementsAre(1, 2, 3, 2));

  const string expected = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
  { 1, 2 },
  { 1001, 1002 },
  { 2001, 2002 }
},
{ /*i1=1*/
  { 1, 2 },
  { 1001, 1002 },
  { 2001, 2002 }
}
}
})";
  EXPECT_EQ(expected, projected.ToString());
}
318
// Stringifies the fixture's dim0-major 2x2x3x3 literal.
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
  const Literal& r4 = literal_r4_2x2x3x3_dim0major_;
  EXPECT_THAT(r4.shape().dimensions(), ElementsAre(2, 2, 3, 3));

  const string expected = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
  { 1, 2, 3 },
  { 4, 5, 6 },
  { 7, 8, 9 }
},
{ /*i1=1*/
  { 11, 12, 13 },
  { 14, 15, 16 },
  { 17, 18, 19 }
}
},
{ /*i0=1*/
{ /*i1=0*/
  { 101, 102, 103 },
  { 104, 105, 106 },
  { 107, 108, 109 }
},
{ /*i1=1*/
  { 201, 202, 203 },
  { 204, 205, 206 },
  { 207, 208, 209 }
}
}
})";
  EXPECT_EQ(expected, r4.ToString());
}
351
// EachCellAsString visits cells in row-major index order and passes each
// element's string form along with its indices.
TEST_F(LiteralUtilTest, EachCellR2F32) {
  // clang-format off
  auto literal = LiteralUtil::CreateR2<float>({
    {3.1f, 4.2f},
    {9.3f, 12.4f},
  });
  // clang-format on
  using Elem = std::tuple<int64, int64, string>;
  std::vector<Elem> visited;
  literal.EachCellAsString(
      [&visited](absl::Span<const int64> indices, const string& value) {
        visited.emplace_back(indices[0], indices[1], value);
      });

  const std::vector<Elem> expected = {Elem(0, 0, "3.1"), Elem(0, 1, "4.2"),
                                      Elem(1, 0, "9.3"), Elem(1, 1, "12.4")};
  EXPECT_EQ(expected, visited);
}
370
// Scalar literals compare equal only when both value and element type match.
TEST_F(LiteralUtilTest, ScalarEquality) {
  auto f32_forty_two = LiteralUtil::CreateR0<float>(42.0);
  auto f32_forty_two_again = LiteralUtil::CreateR0<float>(42.0);
  auto f32_other_value = LiteralUtil::CreateR0<float>(123.0);
  auto f64_forty_two = LiteralUtil::CreateR0<double>(42.0);

  EXPECT_EQ(f32_forty_two, f32_forty_two);
  EXPECT_EQ(f32_forty_two, f32_forty_two_again);
  EXPECT_NE(f32_forty_two, f32_other_value);
  // Same numeric value, different element type.
  EXPECT_NE(f32_forty_two, f64_forty_two);
}
385
// Array literals are equal only if shape and every element match. A
// same-sized vector, a scalar, or the nil tuple never equal a matrix.
TEST_F(LiteralUtilTest, NonScalarEquality) {
  auto m = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto m_copy = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto m_other = LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
  auto flat = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
  auto one = LiteralUtil::CreateR0<float>(1.0);
  Literal nil(ShapeUtil::MakeNil());

  EXPECT_EQ(m, m);
  EXPECT_EQ(m, m_copy);
  EXPECT_NE(m, m_other);
  EXPECT_NE(m, flat);
  EXPECT_NE(m, one);
  EXPECT_NE(m, nil);
  EXPECT_EQ(nil, nil);
}
404
// All tokens compare equal to each other. Inside tuples, only element
// position matters, not which token object is used.
TEST_F(LiteralUtilTest, TokenEquality) {
  auto first_token = LiteralUtil::CreateToken();
  auto second_token = LiteralUtil::CreateToken();
  auto one = LiteralUtil::CreateR0<float>(1.0);

  EXPECT_EQ(first_token, second_token);
  EXPECT_NE(first_token, one);

  EXPECT_EQ(LiteralUtil::MakeTuple({&first_token}),
            LiteralUtil::MakeTuple({&first_token}));
  EXPECT_EQ(LiteralUtil::MakeTuple({&first_token, &one}),
            LiteralUtil::MakeTuple({&second_token, &one}));
  EXPECT_NE(LiteralUtil::MakeTuple({&first_token, &one}),
            LiteralUtil::MakeTuple({&one, &second_token}));
}
420
// Equality compares logical element values, so the same 2x2 matrix stored
// column-major and row-major compares equal.
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
  auto fill = [](Literal* m) {
    m->Set<float>({0, 0}, 1.0);
    m->Set<float>({0, 1}, 2.0);
    m->Set<float>({1, 0}, 3.0);
    m->Set<float>({1, 1}, 4.0);
  };

  Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
  fill(&colmajor);

  Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
  fill(&rowmajor);

  EXPECT_EQ(rowmajor, colmajor);
}
437
// Tuples are equal when their elements are equal position by position.
// Reordering or changing an element breaks equality.
TEST_F(LiteralUtilTest, TupleEquality) {
  auto one = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto reference = LiteralUtil::MakeTuple({&one, &matrix});

  // Same elements: the matrix is shared, the scalar is a fresh copy.
  auto one_again = LiteralUtil::CreateR0<float>(1.0);
  EXPECT_EQ(reference, LiteralUtil::MakeTuple({&one_again, &matrix}));

  // Same elements, swapped order.
  EXPECT_NE(reference, LiteralUtil::MakeTuple({&matrix, &one}));

  // First element has a different value.
  auto forty_two = LiteralUtil::CreateR0<float>(42.0);
  EXPECT_NE(reference, LiteralUtil::MakeTuple({&forty_two, &matrix}));
}
459
// Checks ==/!= on rank-1 complex64 literals. Identical element values compare
// equal; the same values in a different order do not.
TEST_F(LiteralUtilTest, C64Equality) {
  auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});

  // An independently constructed literal with the same elements.
  auto vector_clone =
      LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same elements in reversed order.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
474
// Checks ==/!= on rank-1 complex128 literals. Identical element values compare
// equal; the same values in a different order do not.
TEST_F(LiteralUtilTest, C128Equality) {
  auto vector = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});

  // An independently constructed literal with the same elements.
  auto vector_clone =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same elements in reversed order.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
489
// IsAll on a tuple is always false, even when every leaf element holds the
// queried value.
TEST_F(LiteralUtilTest, IsAllTuple) {
  auto element1 = LiteralUtil::CreateR0<float>(0.0);
  auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
  // Mix a scalar and a matrix (both all-zero) so the check covers a
  // heterogeneous tuple.
  auto tuple = LiteralUtil::MakeTuple({&element1, &element2});

  // Tuples should always return false for IsAll.
  EXPECT_FALSE(tuple.IsAll(0));
  EXPECT_FALSE(tuple.IsAll(1));
}
499
500 // Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest,CreateFromShapeTuple)501 TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
502 auto scalar = LiteralUtil::CreateR0<float>(0.0);
503 auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
504 auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
505
506 auto x = Literal::CreateFromShape(tuple.shape());
507 EXPECT_EQ(tuple, x);
508 }
509
// Exercises IsAll(value) across element types. Every element is compared
// against an integer value, and signed values are never reinterpreted as
// unsigned.
TEST_F(LiteralUtilTest, IsAll) {
  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));

  // We shouldn't reinterpret int8_min as an unsigned type and then decide that
  // it is equal to 255.
  auto int8_min = std::numeric_limits<int8>::min();
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));

  EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));

  EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
  EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));

  EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));

  half h8(8.0f);
  half h9(9.0f);
  EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));

  bfloat16 b8(8.0f);
  bfloat16 b9(9.0f);

  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));

  // 9.001 is not representable in bfloat16 and is stored as 9.0, so both
  // elements compare equal to 9. The argument is an integer like every other
  // IsAll call here, not a floating-point literal.
  bfloat16 b91(9.001f);
  bfloat16 b90(9.00f);
  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9));

  complex64 c8_9 = {8, 9};
  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));

  // -1 must not be reinterpreted as the unsigned all-ones value.
  auto uint64_max = std::numeric_limits<uint64>::max();
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
                   {{uint64_max, uint64_max}, {uint64_max, uint64_max}})
                   .IsAll(-1));
}
560
// IsAllFloat matches only floating-point literals whose every element equals
// the given value exactly.
TEST_F(LiteralUtilTest, IsAllFloat) {
  // Non-floating-point literals never match, even when the value is equal.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));

  // F32.
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
  auto f32_one_nonzero = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}});
  EXPECT_FALSE(f32_one_nonzero.IsAllFloat(0));
  auto f32_halves = LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}});
  EXPECT_TRUE(f32_halves.IsAllFloat(.5));

  // F64.
  EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
  EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
  EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
  EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
  auto f64_one_nonzero =
      LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}});
  EXPECT_FALSE(f64_one_nonzero.IsAllFloat(0));
}
584
// IsAllComplex matches only complex literals whose every element equals the
// given complex value.
TEST_F(LiteralUtilTest, IsAllComplex) {
  // Non-complex literals never match.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));

  const complex64 target = {8.0f, 9.0f};
  complex64 c8_9 = {8, 9};
  complex64 c7_9 = {7, 9};
  EXPECT_TRUE(
      LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllComplex(target));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllComplex(target));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}}).IsAllComplex(target));
}
603
// IsAllFirst is true exactly when every element equals the literal's first
// element, for any element type.
TEST_F(LiteralUtilTest, IsAllFirst) {
  EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
  // Positive uint8 case, matching the pairs for the other integer types.
  EXPECT_TRUE(LiteralUtil::CreateR1<uint8>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());

  complex64 c8_9 = {8, 9};
  complex64 c7_9 = {7, 9};
  EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
621
// IsZero(indices) tests the single element at the given multi-index.
TEST_F(LiteralUtilTest, IsZero) {
  // Scalars are addressed with an empty index.
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(0.0f).IsZero({}));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(1.0f).IsZero({}));

  auto grid = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
  EXPECT_FALSE(grid.IsZero({0, 1}));
  EXPECT_TRUE(grid.IsZero({0, 2}));
  EXPECT_TRUE(grid.IsZero({1, 1}));
  EXPECT_FALSE(grid.IsZero({1, 2}));

  EXPECT_TRUE(LiteralUtil::CreateR0<complex64>(0.0f).IsZero({}));
  EXPECT_FALSE(LiteralUtil::CreateR0<complex64>(0.5f).IsZero({}));
}
639
640 template <typename T>
641 class LiteralUtilTestTemplated : public ::testing::Test {};
642
643 using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
644 TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes);
645
// Relayout changes only the physical layout: the result carries the requested
// layout and still compares equal to the source.
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
  // Non-integral for floating-point types (1/2). The name avoids shadowing
  // the `half` type.
  TypeParam one_half = TypeParam(1) / TypeParam(2);
  auto data = LiteralUtil::CreateR2<TypeParam>({{one_half, 2}, {3, 4}});

  auto expect_relayout_preserves_values = [&data](const Layout& layout) {
    auto relaid = data.Relayout(layout);
    EXPECT_TRUE(LayoutUtil::Equal(relaid.shape().layout(), layout));
    EXPECT_EQ(data, relaid);
  };

  expect_relayout_preserves_values(LayoutUtil::MakeLayout({0, 1}));
  expect_relayout_preserves_values(LayoutUtil::MakeLayout({1, 0}));
}
661
// Reshaping a scalar to zero dimensions yields an equal scalar.
TEST_F(LiteralUtilTest, ReshapeR0) {
  auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  auto reshaped = scalar.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
  EXPECT_EQ(scalar, reshaped);
}
667
// Reshapes a dim0-major F32[1x3x2x4] literal to F32[3x4x2]. The elements keep
// their row-major (logical) order.
TEST_F(LiteralUtilTest, ReshapeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
    {{10, 11, 12, 13}, {14, 15, 16, 17}},
    {{18, 19, 20, 21}, {22, 23, 24, 25}},
    {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // F32[3x4x2] (rank 3: the leading size-1 dimension is dropped).
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();

  EXPECT_EQ(expected, reshape);
}
687
// Same as ReshapeR4, but the source is stored dim0-minor. The reshape must
// follow logical element order, not the physical storage order.
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
    {{10, 11, 12, 13}, {14, 15, 16, 17}},
    {{18, 19, 20, 21}, {22, 23, 24, 25}},
    {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0minor_);
  // F32[3x4x2] (rank 3: the leading size-1 dimension is dropped).
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();

  EXPECT_EQ(expected, reshape);
}
707
// Transposing a scalar with the empty permutation yields an equal scalar.
TEST_F(LiteralUtilTest, TransposeR0) {
  auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  auto transposed = scalar.Transpose(/*permutation=*/{});
  EXPECT_EQ(scalar, transposed);
}
713
// Transposes F32[1x3x2x4] with permutation {2, 3, 0, 1}. Every output cell
// must equal the input cell at {out[2], out[3], out[0], out[1]}.
TEST_F(LiteralUtilTest, TransposeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4<float>({{
    {{10, 11, 12, 13}, {14, 15, 16, 17}},
    {{18, 19, 20, 21}, {22, 23, 24, 25}},
    {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }});
  // clang-format on
  auto transposed = original.Transpose(/*permutation=*/{2, 3, 0, 1});

  // Output dimension i has the size of input dimension permutation[i]:
  // {1, 3, 2, 4} -> {2, 4, 1, 3}. Pin this so the cell check below cannot
  // pass against an unexpectedly shaped result.
  EXPECT_THAT(transposed.shape().dimensions(), ElementsAre(2, 4, 1, 3));

  transposed.EachCell<float>([&](absl::Span<const int64> indices,
                                 float value) {
    EXPECT_EQ(value, original.Get<float>(
                         {indices[2], indices[3], indices[0], indices[1]}));
  });
}
730
// Relayout of an array must equal building it directly in the target layout,
// in both directions.
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
  const Literal& major = literal_r4_2x2x3x3_dim0major_;
  const Literal& minor = literal_r4_2x2x3x3_dim0minor_;

  auto minor_as_major = minor.Relayout(layout_r4_dim0major_);
  EXPECT_EQ(major, minor_as_major);

  auto major_as_minor = major.Relayout(layout_r4_dim0minor_);
  EXPECT_EQ(minor, major_as_minor);
}
742
// Checks the flat memory order of a 2x3 S32 matrix in dim0-minor
// (column-major) and dim0-major (row-major) layouts, including after Relayout.
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
  // Column-major literal stores columns contiguously.
  auto col_major = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
  EXPECT_EQ(col_major.element_count(), 6);
  EXPECT_THAT(col_major.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));

  // ...and relaying it out to row-major stores rows contiguously.
  auto col_to_row = col_major.Relayout(layout_r2_dim0major_);
  EXPECT_THAT(col_to_row.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));

  // Row-major literal built directly.
  auto row_major = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
  EXPECT_EQ(row_major.element_count(), 6);
  EXPECT_THAT(row_major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));

  // ...and relaid out to column-major.
  auto row_to_col = row_major.Relayout(layout_r2_dim0minor_);
  EXPECT_THAT(row_to_col.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
}
766
// Checks the flat memory order of a 2x2x3 literal in dim0-minor and dim0-major
// layouts, both when created directly and after Relayout.
TEST_F(LiteralUtilTest, TestR3LinearLayout) {
  // clang-format off
  Array3D<int> cube({
    {
      {1, 2, 3},
      {4, 5, 6},
    },
    {
      {7, 8, 9},
      {10, 11, 12},
    },
  });
  // clang-format on
  const std::vector<int> minor_order{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
  const std::vector<int> major_order{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

  // Created dim0-minor (column-major), then relaid out to row-major.
  auto minor_lit = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      cube, layout_r3_dim0minor_);
  EXPECT_EQ(minor_lit.element_count(), 12);
  EXPECT_THAT(minor_lit.data<int32>(), testing::ElementsAreArray(minor_order));

  auto minor_to_major = minor_lit.Relayout(layout_r3_dim0major_);
  EXPECT_THAT(minor_to_major.data<int32>(),
              testing::ElementsAreArray(major_order));

  // Created dim0-major (row-major), then relaid out to column-major.
  auto major_lit = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      cube, layout_r3_dim0major_);
  EXPECT_EQ(major_lit.element_count(), 12);
  EXPECT_THAT(major_lit.data<int32>(), testing::ElementsAreArray(major_order));

  auto major_to_minor = major_lit.Relayout(layout_r3_dim0minor_);
  EXPECT_THAT(major_to_minor.data<int32>(),
              testing::ElementsAreArray(minor_order));
}
807
// Slicing a scalar with empty bounds returns an equal scalar.
TEST_F(LiteralUtilTest, SliceR0S32) {
  auto scalar = LiteralUtil::CreateR0<int32>(1);
  EXPECT_EQ(scalar, scalar.Slice({}, {}));
}
813
// Slice bounds are [start, limit): {3}..{4} selects only the fourth element.
TEST_F(LiteralUtilTest, SliceR1F32) {
  auto vec = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
  auto sliced = vec.Slice({3}, {4});
  EXPECT_EQ(LiteralUtil::CreateR1<float>({4.0}), sliced);
}
820
// Takes rows [0, 2) and columns [2, 4) of a 3x4 matrix.
TEST_F(LiteralUtilTest, SliceR2U32) {
  auto matrix = LiteralUtil::CreateR2<uint32>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  auto top_right = matrix.Slice({0, 2}, {2, 4});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}}), top_right);
}
828
// A slice covering every dimension fully returns an equal literal.
TEST_F(LiteralUtilTest, SliceR3U32Full) {
  auto cube = LiteralUtil::CreateR3<uint32>(
      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
  EXPECT_EQ(cube, cube.Slice({0, 0, 0}, {2, 3, 2}));
}
835
// PopulateR1 fills a preallocated S64[1] literal.
TEST_F(LiteralUtilTest, PopulateR1S64) {
  Literal populated(ShapeUtil::MakeShape(S64, {1}));
  populated.PopulateR1<int64>({77});
  EXPECT_EQ(LiteralUtil::CreateR1<int64>({77}), populated);
}
842
// PopulateR1 fills a preallocated U64[2] literal.
TEST_F(LiteralUtilTest, PopulateR1U64) {
  Literal populated(ShapeUtil::MakeShape(U64, {2}));
  populated.PopulateR1<uint64>({77, 88});
  EXPECT_EQ(LiteralUtil::CreateR1<uint64>({77, 88}), populated);
}
849
// PopulateR1 fills a preallocated C64[1] literal with the value 77+88i.
TEST_F(LiteralUtilTest, PopulateR1C64) {
  Literal populated(ShapeUtil::MakeShape(C64, {1}));
  populated.PopulateR1<complex64>({{77, 88}});
  EXPECT_EQ(LiteralUtil::CreateR1<complex64>({{77, 88}}), populated);
}
856
// PopulateR1 fills a preallocated C128[1] literal with the value 77+88i.
TEST_F(LiteralUtilTest, PopulateR1C128) {
  Literal populated(ShapeUtil::MakeShape(C128, {1}));
  populated.PopulateR1<complex128>({{77, 88}});
  EXPECT_EQ(LiteralUtil::CreateR1<complex128>({{77, 88}}), populated);
}
863
// PopulateR2 fills a preallocated C64[2,2] literal row by row.
TEST_F(LiteralUtilTest, PopulateR2C64) {
  Literal populated(ShapeUtil::MakeShape(C64, {2, 2}));
  populated.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
  EXPECT_EQ(
      LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}),
      populated);
}
871
// PopulateWithValue sets the single element of a BF16 scalar.
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
  const bfloat16 quarter(0.25f);
  Literal populated(ShapeUtil::MakeShape(BF16, {}));
  populated.PopulateWithValue<bfloat16>(quarter);
  EXPECT_EQ(LiteralUtil::CreateR0<bfloat16>(quarter), populated);
}
879
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
  // Filling a three-element BF16 vector sets every element to the value.
  const bfloat16 fill(0.5f);
  Literal populated(ShapeUtil::MakeShape(BF16, {3}));
  populated.PopulateWithValue<bfloat16>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR1<bfloat16>({fill, fill, fill}));
}
887
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
  // Filling a 2x2 BF16 matrix sets every element to the value.
  const bfloat16 fill(2.0f);
  Literal populated(ShapeUtil::MakeShape(BF16, {2, 2}));
  populated.PopulateWithValue<bfloat16>(fill);
  EXPECT_EQ(populated,
            LiteralUtil::CreateR2<bfloat16>({{fill, fill}, {fill, fill}}));
}
895
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
  // Filling an F32 scalar with 2.5 yields the equivalent R0 literal.
  constexpr float kFill = 2.5f;
  Literal populated(ShapeUtil::MakeShape(F32, {}));
  populated.PopulateWithValue<float>(kFill);
  EXPECT_EQ(populated, LiteralUtil::CreateR0<float>(kFill));
}
902
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
  // Filling a three-element S64 vector with -7 sets every element to -7.
  const int64 fill = -7;
  Literal populated(ShapeUtil::MakeShape(S64, {3}));
  populated.PopulateWithValue<int64>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR1<int64>({fill, fill, fill}));
}
909
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
  // Filling a 2x2 U64 matrix with 42 sets every element to 42.
  const uint64 fill = 42;
  Literal populated(ShapeUtil::MakeShape(U64, {2, 2}));
  populated.PopulateWithValue<uint64>(fill);
  EXPECT_EQ(populated,
            LiteralUtil::CreateR2<uint64>({{fill, fill}, {fill, fill}}));
}
916
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
  // Filling a 2x2 complex64 matrix with 4 + 2i sets every element to it.
  const complex64 fill(4, 2);
  Literal populated(ShapeUtil::MakeShape(C64, {2, 2}));
  populated.PopulateWithValue<complex64>(fill);
  EXPECT_EQ(populated,
            LiteralUtil::CreateR2<complex64>({{fill, fill}, {fill, fill}}));
}
924
TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
  // Filling a 2x2 complex128 matrix with 4 + 2i sets every element to it.
  const complex128 fill(4, 2);
  Literal populated(ShapeUtil::MakeShape(C128, {2, 2}));
  populated.PopulateWithValue<complex128>(fill);
  EXPECT_EQ(populated,
            LiteralUtil::CreateR2<complex128>({{fill, fill}, {fill, fill}}));
}
932
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
  // Filling an F16 scalar with a value yields the equivalent R0 literal.
  const half fill(0.25f);
  Literal populated(ShapeUtil::MakeShape(F16, {}));
  populated.PopulateWithValue<half>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR0<half>(fill));
}
940
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
  // Filling a three-element F16 vector sets every element to the value.
  const half fill(0.5f);
  Literal populated(ShapeUtil::MakeShape(F16, {3}));
  populated.PopulateWithValue<half>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR1<half>({fill, fill, fill}));
}
948
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
  // Filling a 2x2 F16 matrix sets every element to the value.
  const half fill(2.0f);
  Literal populated(ShapeUtil::MakeShape(F16, {2, 2}));
  populated.PopulateWithValue<half>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR2<half>({{fill, fill}, {fill, fill}}));
}
956
TEST_F(LiteralUtilTest, ReplicateR2U32) {
  // Replicate(3) on a 3x4 matrix yields a 3x3x4 literal whose leading index
  // selects one of three identical copies of the matrix.
  auto matrix = LiteralUtil::CreateR2<uint32>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  auto replicated = matrix.Replicate<uint32>(3);
  auto three_copies = LiteralUtil::CreateR3<uint32>(
      {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
  EXPECT_EQ(replicated, three_copies);
}
967
TEST_F(LiteralUtilTest, CopySliceFrom) {
  // For several layouts: number the elements of a 17x15x34x21 source literal
  // 1, 2, 3, ... in ForEachIndex order, copy a 7x8x11x9 window into a
  // freshly created literal at a different offset, then check that every
  // element of the destination window is nonzero and equals its source.
  const int64 dimensions[] = {17, 15, 34, 21};
  const int64 layouts[][4] = {
      {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
  const int64 zero_base[] = {0, 0, 0, 0};
  const int64 step[] = {1, 1, 1, 1};
  const int64 src_base[] = {3, 1, 5, 7};
  const int64 dest_base[] = {6, 4, 12, 2};
  const int64 copy_size[] = {7, 8, 11, 9};
  const size_t rank = TF_ARRAYSIZE(dimensions);

  for (const auto& layout : layouts) {
    const Shape shape = ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);

    // Numbering starts at 1 so a zero in the destination means "not copied".
    auto source = Literal::CreateFromShape(shape);
    uint32 next_value = 0;
    ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
                            [&](absl::Span<const int64> index) {
                              source.Set(index, ++next_value);
                              return true;
                            });

    auto blank = Literal::CreateFromShape(shape);
    TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));

    std::vector<int64> source_index(rank, 0);
    std::vector<int64> blank_index(rank, 0);
    bool matched = true;
    ShapeUtil::ForEachIndex(
        source.shape(), zero_base, copy_size, step,
        [&](absl::Span<const int64> offset) {
          // Translate the window-relative offset into absolute indices.
          for (size_t d = 0; d < offset.size(); ++d) {
            source_index[d] = src_base[d] + offset[d];
            blank_index[d] = dest_base[d] + offset[d];
          }
          const uint32 copied = blank.Get<uint32>(blank_index);
          // Returning false stops the iteration at the first mismatch.
          matched = copied != 0 && copied == source.Get<uint32>(source_index);
          return matched;
        });
    EXPECT_TRUE(matched);
  }
}
1013
TEST_F(LiteralUtilTest, CopyFromScalars) {
  // Scalar-to-scalar CopyFrom, then single-element CopySliceFrom between a
  // scalar and a vector in both directions (rank-0 side uses empty spans).
  auto scalar = LiteralUtil::CreateR0<uint32>(0);
  auto nine = LiteralUtil::CreateR0<uint32>(9);
  TF_EXPECT_OK(scalar.CopyFrom(nine));
  EXPECT_EQ(scalar, nine);

  auto vec = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
  // vec[5] -> scalar.
  TF_EXPECT_OK(scalar.CopySliceFrom(vec, /*src_base=*/{5}, /*dest_base=*/{},
                                    /*copy_size=*/{}));
  EXPECT_EQ(scalar.Get<uint32>({}), 17);
  // scalar -> vec[4].
  TF_EXPECT_OK(vec.CopySliceFrom(scalar, /*src_base=*/{}, /*dest_base=*/{4},
                                 /*copy_size=*/{}));
  EXPECT_EQ(vec.Get<uint32>({4}), 17);
}
1026
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
  // Zero-sized slice copies succeed and leave the destination unchanged,
  // whether the zero-element literal is the source or the destination.
  const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
  const auto const_nine = LiteralUtil::CreateR1<float>({9});
  const auto const_empty = Literal::CreateFromShape(empty_r1_shape);

  {
    // The source has a dimension with zero elements.
    const auto empty_source = Literal::CreateFromShape(empty_r1_shape);
    auto destination = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(destination.CopySliceFrom(empty_source, /*src_base=*/{0},
                                           /*dest_base=*/{0},
                                           /*copy_size=*/{0}));
    EXPECT_EQ(destination, const_nine);
  }

  {
    // Zero elements are copied into a destination with zero elements.
    auto empty_destination = Literal::CreateFromShape(empty_r1_shape);
    auto source = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(empty_destination.CopySliceFrom(source, /*src_base=*/{0},
                                                 /*dest_base=*/{0},
                                                 /*copy_size=*/{0}));
    EXPECT_EQ(empty_destination, const_empty);
  }
}
1050
TEST_F(LiteralUtilTest, CopyFromNilShape) {
  // Copying between two nil (empty tuple) literals moves no data but must
  // still report success.
  Literal destination(ShapeUtil::MakeNil());
  Literal source(ShapeUtil::MakeNil());
  TF_ASSERT_OK(destination.CopyFrom(source));
}
1057
TEST_F(LiteralUtilTest, CopyFromArrays) {
  // CopyFrom with root shape indices overwrites an array literal with another
  // array literal of the same shape; covered for a scalar and a 2x2 matrix.
  auto dest_scalar = LiteralUtil::CreateR0<float>(42.0);
  auto src_scalar = LiteralUtil::CreateR0<float>(123.0);
  EXPECT_NE(dest_scalar, src_scalar);
  TF_ASSERT_OK(dest_scalar.CopyFrom(src_scalar, /*dest_shape_index=*/{},
                                    /*src_shape_index=*/{}));
  EXPECT_EQ(dest_scalar, src_scalar);
  EXPECT_EQ(dest_scalar.Get<float>({}), 123.0f);

  auto dest_matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto src_matrix = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
  EXPECT_NE(dest_matrix, src_matrix);
  EXPECT_EQ(dest_matrix.Get<float>({0, 0}), 1.0f);
  TF_ASSERT_OK(dest_matrix.CopyFrom(src_matrix, /*dest_shape_index=*/{},
                                    /*src_shape_index=*/{}));
  EXPECT_EQ(dest_matrix, src_matrix);
  EXPECT_EQ(dest_matrix.Get<float>({0, 0}), 5.0f);
}
1076
TEST_F(LiteralUtilTest, CopyFromTuples) {
  // Builds (f32[2,2], (s32, f64[2], ())) and overwrites its inner tuple, at
  // shape index {1}, with a same-shaped tuple holding different values.
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal nil(ShapeUtil::MakeNil());
  Literal scalar_42 = LiteralUtil::CreateR0<int32>(42);
  Literal vector_23_44 = LiteralUtil::CreateR1<double>({23.0, 44.0});
  Literal inner_tuple =
      LiteralUtil::MakeTuple({&scalar_42, &vector_23_44, &nil});
  Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});

  // Same shape as nested_tuple's inner tuple, but with different values.
  Literal scalar_minus5 = LiteralUtil::CreateR0<int32>(-5);
  Literal vector_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
  Literal replacement =
      LiteralUtil::MakeTuple({&scalar_minus5, &vector_2_4, &nil});

  // Checks the scalar and both vector entries of nested_tuple's inner tuple.
  auto expect_inner_tuple = [&nested_tuple](int32 scalar, double v0,
                                            double v1) {
    EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), scalar);
    EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), v0);
    EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), v1);
  };

  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
  expect_inner_tuple(42, 23.0, 44.0);

  TF_ASSERT_OK(nested_tuple.CopyFrom(replacement, /*dest_shape_index=*/{1},
                                     /*src_shape_index=*/{}));

  // The matrix element is untouched; the inner tuple now holds the new values.
  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
  expect_inner_tuple(-5, 2.0, 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
  // CopyFrom may use the destination literal as its own source: element {0}
  // of a two-element tuple is copied over element {1}.
  Literal minus_two = LiteralUtil::CreateR0<int32>(-2);
  Literal four = LiteralUtil::CreateR0<int32>(4);
  Literal tuple = LiteralUtil::MakeTuple({&minus_two, &four});

  EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
  EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);

  TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
                              /*src_shape_index=*/{0}));

  // Element {0} is unchanged and element {1} now holds the same value.
  EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
  EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
}
1125
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
  // Copying an f32[2] vector into an f32[2,2] matrix must fail with an error
  // that mentions the incompatible destination subshape.
  auto destination = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto source = LiteralUtil::CreateR1<float>({5.0, 7.0});
  const Status status = destination.CopyFrom(source);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Destination subshape incompatible"));
}
1134
TEST_F(LiteralUtilTest, F16) {
  // Verifies that the raw bytes behind F16 literal data are consistent and in
  // little-endian order: all zero for a literal created from a shape, and the
  // expected encodings for explicit 1.0 / 2.0 values.
  // TODO - modify if we make the data format machine endianness dependent
  constexpr int kNumBytes = 8;  // 2x2 elements, 2 bytes each.

  Literal zeros = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
  const char* zero_bytes =
      reinterpret_cast<const char*>(zeros.data<half>().data());
  for (int i = 0; i < kNumBytes; ++i) {
    EXPECT_EQ(zero_bytes[i], 0);
  }

  const half one(1.0f);
  const half two(2.0f);
  auto values = LiteralUtil::CreateR2<half>({{one, two}, {two, one}});
  const char* value_bytes =
      reinterpret_cast<const char*>(values.data<half>().data());
  // IEEE half: 1.0 is 0x3C00 and 2.0 is 0x4000; the low byte comes first.
  const char expected_bytes[kNumBytes] = {0, 0x3C, 0, 0x40, 0, 0x40, 0, 0x3C};
  for (int i = 0; i < kNumBytes; ++i) {
    EXPECT_EQ(value_bytes[i], expected_bytes[i]);
  }
}
1163
TEST_F(LiteralUtilTest, Populate) {
  // For each dimensions/layout pair, Populate<uint32> with a generator that
  // returns (linear index + 17), then check every element reads back the
  // generator's value for its index.
  struct PopulateData {
    std::vector<int64> dimensions;
    std::vector<int64> layout;
  };
  const PopulateData populate_data[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };
  for (const PopulateData& data : populate_data) {
    Literal literal(ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
        data.layout));
    // The +17 offset keeps even the rank-0 literal's value nonzero.
    auto generator = [&literal](absl::Span<const int64> indexes) -> uint32 {
      const int64 linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                        indexes);
      return static_cast<uint32>(linear_index + 17);
    };
    TF_EXPECT_OK(literal.Populate<uint32>(generator));

    const std::vector<int64> zero_base(data.dimensions.size(), 0);
    const std::vector<int64> step(data.dimensions.size(), 1);
    bool matched = true;
    ShapeUtil::ForEachIndex(
        literal.shape(), zero_base, data.dimensions, step,
        [&](absl::Span<const int64> indexes) {
          if (literal.Get<uint32>(indexes) != generator(indexes)) {
            matched = false;
          }
          return matched;  // Stop iterating after the first mismatch.
        });
    EXPECT_TRUE(matched);
  }
}
1205
TEST_F(LiteralUtilTest, PopulateParallel) {
  // Same as Populate, but through PopulateParallel<uint32>: each element must
  // read back (linear index + 17) for its index.
  struct PopulateData {
    std::vector<int64> dimensions;
    std::vector<int64> layout;
  };
  const PopulateData populate_data[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };
  for (const PopulateData& data : populate_data) {
    Literal literal(ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
        data.layout));
    // The +17 offset keeps even the rank-0 literal's value nonzero.
    auto generator = [&literal](absl::Span<const int64> indexes) -> uint32 {
      const int64 linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                        indexes);
      return static_cast<uint32>(linear_index + 17);
    };
    TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));

    const std::vector<int64> zero_base(data.dimensions.size(), 0);
    const std::vector<int64> step(data.dimensions.size(), 1);
    bool matched = true;
    ShapeUtil::ForEachIndex(
        literal.shape(), zero_base, data.dimensions, step,
        [&](absl::Span<const int64> indexes) {
          if (literal.Get<uint32>(indexes) != generator(indexes)) {
            matched = false;
          }
          return matched;  // Stop iterating after the first mismatch.
        });
    EXPECT_TRUE(matched);
  }
}
1247
TEST_F(LiteralUtilTest, ConvertR4) {
  // Converting an S8 R4 literal to U32 yields the U32 literal with the same
  // values and layout argument.
  // clang-format off
  auto s8_literal = LiteralUtil::CreateR4WithLayout<int8>({{
    {{10, 11, 12, 13}, {14, 15, 16, 17}},
    {{18, 19, 20, 21}, {22, 23, 24, 25}},
    {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  auto u32_literal = LiteralUtil::CreateR4WithLayout<uint32>({{
    {{10, 11, 12, 13}, {14, 15, 16, 17}},
    {{18, 19, 20, 21}, {22, 23, 24, 25}},
    {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // clang-format on
  TF_ASSERT_OK_AND_ASSIGN(Literal converted, s8_literal.Convert(U32));
  EXPECT_EQ(u32_literal, converted);
}
1265
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
  // Builds the same R4 pattern in many element types, checks that Convert
  // between supported type pairs reproduces the pattern in the target type,
  // and checks that the listed unsupported conversions report UNIMPLEMENTED.
  // clang-format off
  auto s8 = LiteralUtil::CreateR4WithLayout<int8>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto s16 = LiteralUtil::CreateR4WithLayout<int16>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto s32 = LiteralUtil::CreateR4WithLayout<int32>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto u16 = LiteralUtil::CreateR4WithLayout<uint16>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto u32 = LiteralUtil::CreateR4WithLayout<uint32>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto s64 = LiteralUtil::CreateR4WithLayout<int64>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto u64 = LiteralUtil::CreateR4WithLayout<uint64>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
    {{true, false, true, false}, {false, true, false, true}},
    {{false, true, false, true}, {true, false, true, false}},
    {{true, false, true, false}, {false, true, false, true}},
  }}, layout_r4_dim0major_);
  auto int32_pred = LiteralUtil::CreateR4WithLayout<int32>({{
    {{1, 0, 1, 0}, {0, 1, 0, 1}},
    {{0, 1, 0, 1}, {1, 0, 1, 0}},
    {{1, 0, 1, 0}, {0, 1, 0, 1}},
  }}, layout_r4_dim0major_);
  auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
    {{half(10.0), half(0.0), half(12.0), half(0.0)},
     {half(0.0), half(15.0), half(0.0), half(17.0)}},
    {{half(0.0), half(19.0), half(0.0), half(21.0)},
     {half(22.0), half(0.0), half(24.0), half(0.0)}},
    {{half(26.0), half(0.0), half(28.0), half(0.0)},
     {half(0.0), half(31.0), half(0.0), half(33.0)}},
  }}, layout_r4_dim0major_);
  auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
    {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
     {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
    {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
     {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
    {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
     {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
  }}, layout_r4_dim0major_);
  auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
    {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
    {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
    {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
  }}, layout_r4_dim0major_);
  auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
    {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
    {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
    {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
  }}, layout_r4_dim0major_);
  auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
    {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
    {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
    {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
  }}, layout_r4_dim0major_);
  auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
    {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
    {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
    {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
  }}, layout_r4_dim0major_);
  // clang-format on

  // Each entry: source literal, target element type, expected result.
  struct SupportedConversion {
    const Literal* source;
    PrimitiveType target;
    const Literal* expected;
  };
  const SupportedConversion supported[] = {
      {&s8, U16, &u16},          {&s8, S16, &s16},   {&s8, U32, &u32},
      {&s8, S32, &s32},          {&s8, U64, &u64},   {&s8, S64, &s64},
      {&s8, PRED, &pred},        {&bf16, S32, &s32}, {&bf16, F32, &f32},
      {&pred, S32, &int32_pred}, {&f32, S32, &s32},  {&f64, S32, &s32},
      {&s32, F32, &f32},         {&f32, F16, &f16},  {&f64, F16, &f16},
      {&s32, F16, &f16},         {&u32, F16, &f16},  {&s32, C64, &c64},
      {&f16, C64, &c64},         {&s32, S16, &s16},  {&s32, U16, &u16},
      {&s32, C128, &c128},       {&f16, C128, &c128},
  };
  for (const SupportedConversion& conversion : supported) {
    SCOPED_TRACE(ShapeUtil::HumanString(conversion.source->shape()) + " -> " +
                 ShapeUtil::HumanString(conversion.expected->shape()));
    Literal converted =
        conversion.source->Convert(conversion.target).ConsumeValueOrDie();
    EXPECT_EQ(converted, *conversion.expected);
  }

  // Conversions that must be rejected as unimplemented.
  struct UnsupportedConversion {
    const Literal* source;
    PrimitiveType target;
  };
  const UnsupportedConversion unsupported[] = {
      {&s32, TUPLE}, {&c64, F32}, {&c64, S32}, {&c128, F32}, {&c128, S32},
  };
  for (const UnsupportedConversion& conversion : unsupported) {
    EXPECT_EQ(conversion.source->Convert(conversion.target).status().code(),
              tensorflow::error::UNIMPLEMENTED);
  }
}
1429
TEST_F(LiteralUtilTest, BitcastConvert) {
  // Reinterpreting U32 bit patterns as F32 must yield exactly the floats those
  // patterns were taken from (plus the float whose bits are 0xbeef).
  auto original = LiteralUtil::CreateR1<uint32>(
      {absl::bit_cast<uint32>(2.5f), absl::bit_cast<uint32>(-42.25f),
       absl::bit_cast<uint32>(100.f), 0xbeef});
  auto expected = LiteralUtil::CreateR1<float>(
      {2.5f, -42.25f, 100.0f, absl::bit_cast<float>(0xbeef)});
  TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
  // Previously `expected` was built but never compared, so the test only
  // checked that BitcastConvert returned OK and never checked its values.
  EXPECT_EQ(expected, converted);
}
1438
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
  // Bitcasting a U32 scalar to F64 must fail, and the error message must say
  // that the bit widths are different.
  auto u32_scalar = LiteralUtil::CreateR0<uint32>(1234);
  const Status status = u32_scalar.BitcastConvert(F64).status();
  EXPECT_NE(Status::OK(), status);
  EXPECT_THAT(status.error_message(), HasSubstr("bit widths are different"));
}
1446
1447 // Sets the layout of the given ShapeProto to the default.
SetDefaultLayoutOnProto(ShapeProto * shape_proto)1448 void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
1449 CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
1450 shape_proto->mutable_layout()->set_format(DENSE);
1451 auto* minor_to_major =
1452 shape_proto->mutable_layout()->mutable_minor_to_major();
1453 minor_to_major->Resize(shape_proto->dimensions_size(), 0);
1454 const int64 size = minor_to_major->size();
1455 for (int64 i = 0; i < size; ++i) {
1456 minor_to_major->Set(i, size - 1 - i);
1457 }
1458 }
1459
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
  // Builds PRED vectors of every length 0..24 from a LiteralProto, using an
  // alternating true/false pattern whose phase depends on the length, and
  // checks that the resulting literal holds exactly that pattern.
  LiteralProto proto;
  proto.mutable_shape()->set_element_type(PRED);
  for (int len = 0; len < 25; ++len) {
    auto expected_at = [len](int i) { return (i % 2) == (len % 2); };

    ShapeProto* shape_proto = proto.mutable_shape();
    shape_proto->clear_dimensions();
    shape_proto->add_dimensions(len);
    SetDefaultLayoutOnProto(shape_proto);
    proto.clear_preds();
    for (int i = 0; i < len; ++i) {
      proto.add_preds(expected_at(i));
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
    const auto values = literal.data<bool>();
    ASSERT_EQ(len, values.size());
    for (int i = 0; i < len; ++i) {
      EXPECT_EQ(expected_at(i), values[i]);
    }
  }
}
1481
1482 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,ToProto_f16)1483 TEST_F(LiteralUtilTest, ToProto_f16) {
1484 half h1(1.0f);
1485 half h2(2.0f);
1486
1487 auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
1488 EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
1489 EXPECT_EQ(4, m.data<half>().size());
1490
1491 LiteralProto p = m.ToProto();
1492 EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
1493 EXPECT_EQ(8, p.f16s().size());
1494 const char* d = p.f16s().data();
1495 EXPECT_EQ(d[0], 0);
1496 EXPECT_EQ(d[1], 0x3C);
1497 EXPECT_EQ(d[2], 0);
1498 EXPECT_EQ(d[3], 0x40);
1499 EXPECT_EQ(d[4], 0);
1500 EXPECT_EQ(d[5], 0x40);
1501 EXPECT_EQ(d[6], 0);
1502 EXPECT_EQ(d[7], 0x3C);
1503 }
1504
1505 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,CopyFromProto_f16)1506 TEST_F(LiteralUtilTest, CopyFromProto_f16) {
1507 half h1(1.0f);
1508 half h2(2.0f);
1509
1510 const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C};
1511 LiteralProto p;
1512 p.mutable_shape()->set_element_type(F16);
1513 p.mutable_shape()->clear_dimensions();
1514 p.mutable_shape()->add_dimensions(4);
1515 SetDefaultLayoutOnProto(p.mutable_shape());
1516 p.clear_f16s();
1517 p.set_f16s(half_vals, 8);
1518 TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
1519 auto r = literal.data<half>();
1520 ASSERT_EQ(4, r.size());
1521 EXPECT_EQ(h1, r[0]);
1522 EXPECT_EQ(h2, r[1]);
1523 EXPECT_EQ(h2, r[2]);
1524 EXPECT_EQ(h1, r[3]);
1525 }
1526
TEST_F(LiteralUtilTest, CopyFromProto_u16) {
  // U16 values are read from a little-endian byte string in the proto.
  const uint16 abcd(0xabcd);
  const uint16 one_two_three_four(0x1234);

  // Little-endian encodings of {0xabcd, 0x1234, 0x1234, 0xabcd}.
  const unsigned char uint16_bytes[8] = {0xcd, 0xab, 0x34, 0x12,
                                         0x34, 0x12, 0xcd, 0xab};
  LiteralProto proto;
  ShapeProto* shape_proto = proto.mutable_shape();
  shape_proto->set_element_type(U16);
  shape_proto->clear_dimensions();
  shape_proto->add_dimensions(4);
  SetDefaultLayoutOnProto(shape_proto);
  proto.clear_u16s();
  proto.set_u16s(uint16_bytes, sizeof(uint16_bytes));

  TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
  const auto values = literal.data<uint16>();
  ASSERT_EQ(4, values.size());
  EXPECT_EQ(abcd, values[0]);
  EXPECT_EQ(one_two_three_four, values[1]);
  EXPECT_EQ(one_two_three_four, values[2]);
  EXPECT_EQ(abcd, values[3]);
}
1548
TEST_F(LiteralUtilTest, LiteralSliceTest) {
  // A LiteralSlice rooted at {} equals the whole literal; a slice rooted at a
  // non-empty shape index equals the sub-literal stored there.
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
  Literal nil(ShapeUtil::MakeNil());

  // Root-level views of every kind of literal.
  for (const Literal* whole : {&scalar, &matrix, &tuple, &nested_tuple, &nil}) {
    EXPECT_EQ(LiteralSlice(*whole, {}), *whole);
  }

  // Views into the flat tuple (scalar, matrix).
  EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
  EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);

  // Views into the nested tuple ((scalar, matrix), scalar).
  EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
  EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
1570
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
  // Writes made to the underlying literal after a LiteralSlice is created are
  // visible when reading through that slice.
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
  const auto nested_tuple_view = LiteralSlice(nested_tuple);
  // Location of the scalar inside the inner tuple.
  const ShapeIndex inner_scalar_index = {0, 0};

  EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, inner_scalar_index),
            1.0f);
  EXPECT_EQ(
      nested_tuple_view.Get<float>(/*multi_index=*/{}, inner_scalar_index),
      1.0f);

  nested_tuple.Set<float>(/*multi_index=*/{}, inner_scalar_index, 555.0f);

  EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, inner_scalar_index),
            555.0f);
  EXPECT_EQ(
      nested_tuple_view.Get<float>(/*multi_index=*/{}, inner_scalar_index),
      555.0f);
}
1591
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
  // Slices compose: taking {1} of the {0} slice of nested_tuple reaches the
  // matrix stored at shape index {0, 1}.
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});

  const LiteralSlice nested_tuple_view(nested_tuple);
  const LiteralSlice tuple_view(nested_tuple_view, /*view_root=*/{0});
  const LiteralSlice matrix_view(tuple_view, /*view_root=*/{1});
  EXPECT_EQ(matrix_view,
            LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
1604
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
  // A BorrowingLiteral built over an existing s64[3] buffer reads back the
  // buffer's values {1, 2, 3}.
  std::vector<int64> int64_values = {1, 2, 3};
  const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});

  BorrowingLiteral literal(reinterpret_cast<const char*>(int64_values.data()),
                           literal_shape);

  for (int64 i = 0; i < 3; ++i) {
    EXPECT_EQ(literal.Get<int64>({i}), i + 1);
  }
}
1616
TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
  // A tuple-shaped BorrowingLiteral takes one buffer pointer per tuple
  // element: here (s64[3] = {1, 2, 3}, s64[1] = {100}).
  std::vector<int64> one_two_three = {1, 2, 3};
  const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});
  std::vector<int64> hundred = {100};
  const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1});

  const std::vector<const char*> src_buf_ptrs = {
      reinterpret_cast<const char*>(one_two_three.data()),
      reinterpret_cast<const char*>(hundred.data())};
  BorrowingLiteral literal_tuple(
      src_buf_ptrs,
      ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape}));

  // Tuple element {0} is backed by one_two_three.
  for (int64 i = 0; i < 3; ++i) {
    EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{i},
                                       /*shape_index=*/{0}),
              i + 1);
  }
  // Tuple element {1} is backed by hundred.
  EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{0}, /*shape_index=*/{1}),
            100);
}
1643
TEST_F(LiteralUtilTest, LiteralMove) {
  // A literal move-constructed from a 2x2 F32 matrix has that matrix's shape
  // and values.
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal moved_to(std::move(matrix));

  EXPECT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), moved_to.shape()));
  const float expected[2][2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  for (int64 row = 0; row < 2; ++row) {
    for (int64 col = 0; col < 2; ++col) {
      EXPECT_EQ(moved_to.Get<float>({row, col}), expected[row][col]);
    }
  }
}
1655
TEST_F(LiteralUtilTest, DecomposeTuple) {
  // DecomposeTuple splits (s32[2,2], (s32, f64[2], ()), ()) into a vector of
  // its three top-level elements and leaves the source as an empty tuple.
  Literal nil_literal(ShapeUtil::MakeNil());
  Literal inner_scalar = LiteralUtil::CreateR0<int32>(42);
  Literal inner_vector = LiteralUtil::CreateR1<double>({23.0, 44.0});
  Literal matrix = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}});
  Literal inner_tuple =
      LiteralUtil::MakeTuple({&inner_scalar, &inner_vector, &nil_literal});
  Literal nested_tuple =
      LiteralUtil::MakeTuple({&matrix, &inner_tuple, &nil_literal});

  EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
  std::vector<Literal> elements = nested_tuple.DecomposeTuple();
  EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));

  ASSERT_EQ(elements.size(), 3);

  // Element 0: the 2x2 s32 matrix.
  const Literal& first = elements[0];
  EXPECT_TRUE(
      ShapeUtil::Compatible(first.shape(), ShapeUtil::MakeShape(S32, {2, 2})));
  const int32 expected_matrix[2][2] = {{1, 2}, {3, 4}};
  for (int64 row = 0; row < 2; ++row) {
    for (int64 col = 0; col < 2; ++col) {
      EXPECT_EQ(first.Get<int32>({row, col}), expected_matrix[row][col]);
    }
  }

  // Element 1: the inner (s32, f64[2], ()) tuple.
  const Literal& second = elements[1];
  EXPECT_TRUE(ShapeUtil::Compatible(
      second.shape(),
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}),
                                 ShapeUtil::MakeShape(F64, {2}),
                                 ShapeUtil::MakeNil()})));
  EXPECT_EQ(second.Get<int32>({}, /*shape_index=*/{0}), 42);
  EXPECT_EQ(second.Get<double>({0}, /*shape_index=*/{1}), 23.0);
  EXPECT_EQ(second.Get<double>({1}, /*shape_index=*/{1}), 44.0);

  // Element 2: the empty tuple.
  EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil()));
}
1694
// Decomposing a nil (empty tuple) literal yields no elements.
TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
  Literal empty_tuple(ShapeUtil::MakeNil());
  const std::vector<Literal> parts = empty_tuple.DecomposeTuple();
  EXPECT_TRUE(parts.empty());
}
1700
// MoveIntoTuple assembles a tuple from a span of literals (including a nested
// tuple) and leaves each source literal as an empty tuple.
TEST_F(LiteralUtilTest, MoveIntoTuple) {
  std::vector<Literal> nested;
  nested.push_back(LiteralUtil::CreateR0<int32>(42));
  nested.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));

  std::vector<Literal> parts;
  parts.push_back(LiteralUtil::CreateR0<float>(1.0));
  parts.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
  parts.push_back(LiteralUtil::MakeTuple({&nested[0], &nested[1]}));

  Literal tuple = Literal::MoveIntoTuple(absl::MakeSpan(parts));
  ASSERT_TRUE(tuple.shape().IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(tuple.shape()), 3);

  // Top-level elements.
  EXPECT_EQ(tuple.Get<float>({}, /*shape_index=*/{0}), 1.0);
  EXPECT_EQ(tuple.Get<int32>({0}, /*shape_index=*/{1}), 4);
  EXPECT_EQ(tuple.Get<int32>({1}, /*shape_index=*/{1}), 8);
  // Elements of the nested tuple at index 2.
  EXPECT_EQ(tuple.Get<int32>({}, /*shape_index=*/{2, 0}), 42);
  EXPECT_EQ(tuple.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
  EXPECT_EQ(tuple.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);

  // Every source literal has been moved from and is now an empty tuple.
  for (const Literal& part : parts) {
    EXPECT_TRUE(ShapeUtil::IsEmptyTuple(part.shape()));
  }
}
1726
// MoveIntoTuple with no input literals produces a zero-element tuple.
TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
  Literal empty_tuple = Literal::MoveIntoTuple({});
  const Shape& shape = empty_tuple.shape();
  ASSERT_TRUE(shape.IsTuple());
  EXPECT_EQ(ShapeUtil::TupleElementCount(shape), 0);
}
1732
// Move-assigning into a default-constructed (nil-shaped) Literal replaces both
// its shape and its element data.
TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
  Literal target;
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), target.shape()));

  Literal source = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  target = std::move(source);

  EXPECT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), target.shape()));
  const auto element = [&target](int64 row, int64 col) {
    return target.Get<float>({row, col});
  };
  EXPECT_EQ(element(0, 0), 1.0);
  EXPECT_EQ(element(0, 1), 2.0);
  EXPECT_EQ(element(1, 0), 3.0);
  EXPECT_EQ(element(1, 1), 4.0);
}
1747
// Copying a LiteralSlice yields a view that reads the same elements.
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const LiteralSlice original_view(matrix);
  const LiteralSlice copied_view = original_view;

  EXPECT_EQ(copied_view.Get<float>({0, 0}), 1.0);
  EXPECT_EQ(copied_view.Get<float>({0, 1}), 2.0);
  EXPECT_EQ(copied_view.Get<float>({1, 0}), 3.0);
  EXPECT_EQ(copied_view.Get<float>({1, 1}), 4.0);
}
1758
// Get/Set with a shape_index address individual elements inside a tuple
// literal, both for a scalar element and for an element of a matrix.
TEST_F(LiteralUtilTest, GetSetTuple) {
  Literal scalar = LiteralUtil::CreateR0<float>(42.0);
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal tuple = LiteralUtil::MakeTuple({&scalar, &matrix});

  // Scalar element at tuple index 0.
  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
  tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);

  // Element [1, 0] of the matrix at tuple index 1.
  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
  tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
            -4.0);
}
1774
// Literal::CreateFromShape zero-initializes every element, for scalar, vector
// and tuple shapes (including PRED and complex element types).
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
  Literal f32_scalar = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
  EXPECT_EQ(f32_scalar.Get<float>({}), 0.0);
  EXPECT_TRUE(f32_scalar.IsAll(0));

  Literal s32_vector =
      Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
  for (int64 i = 0; i < 3; ++i) {
    EXPECT_EQ(s32_vector.Get<int32>({i}), 0);
  }
  EXPECT_TRUE(s32_vector.IsAll(0));

  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
       ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
       ShapeUtil::MakeShape(C128, {})});
  Literal tuple = Literal::CreateFromShape(tuple_shape);

  EXPECT_EQ(tuple.Get<double>({}, /*shape_index=*/{0}), 0.0);
  EXPECT_FALSE(tuple.Get<bool>({0}, /*shape_index=*/{1}));
  EXPECT_FALSE(tuple.Get<bool>({1}, /*shape_index=*/{1}));
  EXPECT_EQ(tuple.Get<uint64>({0, 0}, /*shape_index=*/{2}), 0);
  EXPECT_EQ(tuple.Get<uint64>({1, 0}, /*shape_index=*/{2}), 0);
  EXPECT_EQ(tuple.Get<complex64>({}, /*shape_index=*/{3}),
            complex64(0.0f, 0.0f));
  EXPECT_EQ(tuple.Get<complex128>({}, /*shape_index=*/{4}),
            complex128(0.0, 0.0));
}
1800
// Serializes literals of many element types (and nested tuples containing
// them) to a LiteralProto and back, checking that each round trip reproduces
// the original literal exactly.
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
  auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
  auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
  auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
  auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
  auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_c128 =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
  auto vector_half =
      LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
  auto matrix_pred =
      LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
  auto tuple = LiteralUtil::MakeTuple(
      {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
  Literal nil_literal(ShapeUtil::MakeNil());
  auto nested_tuple =
      LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});

  // Serialize to a proto and immediately deserialize; dies on parse failure.
  auto to_from_proto = [](const Literal& literal) -> Literal {
    return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
  };

  EXPECT_EQ(one_f32, to_from_proto(one_f32));
  EXPECT_EQ(two_f32, to_from_proto(two_f32));
  EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
  EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
  EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
  EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
  EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
  // F16 is otherwise exercised only as an element of `tuple`; round-trip it
  // on its own as well so a standalone F16 regression is reported directly.
  EXPECT_EQ(vector_half, to_from_proto(vector_half));
  EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
  EXPECT_EQ(tuple, to_from_proto(tuple));
  EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
  EXPECT_EQ(nil_literal, to_from_proto(nil_literal));

  // Sanity check that the comparison can actually fail.
  EXPECT_NE(one_f32, two_f32);
  EXPECT_NE(one_f32, to_from_proto(two_f32));
}
1840
// A proto that declares an f32[3] shape but carries no element values is
// rejected by CreateFromProto.
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1850
// A proto that carries PRED values but no shape is rejected by CreateFromProto.
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
  LiteralProto proto;
  for (bool pred : {false, true, false}) {
    proto.add_preds(pred);
  }

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("LiteralProto has no shape"));
}
1861
// A proto whose shape is f32[3] but whose values sit in the PRED container
// (not the f32 one) is reported as having the wrong element count.
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
  for (bool pred : {false, true, false}) {
    proto.add_preds(pred);
  }

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1874
// An f32[42,2] proto (84 elements) holding only three values is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
  for (float value : {1.0f, 2.0f, 3.0f}) {
    proto.add_f32s(value);
  }

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 84 elements in LiteralProto"));
}
1887
// An s32[2] proto holding three values is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
  for (int32 value : {42, -10, 100}) {
    proto.add_s32s(value);
  }

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 2 elements in LiteralProto"));
}
1900
// A pred[2,2] proto whose shape has had its layout cleared is rejected, even
// though the element count is correct.
TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
  proto.mutable_shape()->clear_layout();
  for (bool pred : {true, false, true, false}) {
    proto.add_preds(pred);
  }

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("LiteralProto has no layout"));
}
1914
// A proto whose shape is a 2-element tuple but which supplies only one tuple
// literal is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});

  LiteralProto proto;
  *proto.mutable_shape() = tuple_shape.ToProto();

  // Only the first (pred[2]) element is provided; the f32[] one is missing.
  LiteralProto* first = proto.add_tuple_literals();
  *first->mutable_shape() =
      ShapeUtil::GetTupleElementShape(tuple_shape, 0).ToProto();
  first->add_preds(false);
  first->add_preds(true);

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Expected 2 tuple elements"));
}
1932
// A proto whose shape is a 2-element tuple but which supplies three tuple
// literals is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});

  LiteralProto proto;
  *proto.mutable_shape() = tuple_shape.ToProto();

  // The two elements the shape declares.
  LiteralProto* first = proto.add_tuple_literals();
  *first->mutable_shape() =
      ShapeUtil::GetTupleElementShape(tuple_shape, 0).ToProto();
  first->add_preds(false);
  first->add_preds(true);

  LiteralProto* second = proto.add_tuple_literals();
  *second->mutable_shape() =
      ShapeUtil::GetTupleElementShape(tuple_shape, 1).ToProto();
  second->add_f32s(42.0);

  // One extra element beyond what the shape declares.
  LiteralProto* extra = proto.add_tuple_literals();
  *extra->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
  extra->add_f32s(123.0);

  const Status result = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Expected 2 tuple elements"));
}
1957
// Sparse elements appended out of index order are printed in index order
// after SortSparseElements().
TEST_F(LiteralUtilTest, SortSparseElements) {
  auto sparse = LiteralUtil::CreateSparse<float>({10, 10, 10},
                                                 SparseIndexArray(10, 3), {});
  // (multi-index, value) pairs, deliberately not in sorted order.
  const std::vector<std::pair<std::vector<int64>, float>> appended = {
      {{2, 3, 4}, 2.0f}, {{3, 4, 5}, 3.0f}, {{1, 2, 3}, 1.0f}};
  for (const auto& entry : appended) {
    sparse.AppendSparseElement<float>(entry.first, entry.second);
  }
  sparse.SortSparseElements();

  EXPECT_EQ(sparse.ToString(),
            "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
1968
// GetSparseElementAsString(1) formats the second sparse element for several
// element types (PRED, S64, F64, F16 and C64).
TEST_F(LiteralUtilTest, GetSparseElementAsString) {
  const std::vector<int64> dims = {10, 10, 10};
  const SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});

  Literal bools =
      LiteralUtil::CreateSparse<bool>(dims, indices, {true, false, true});
  EXPECT_EQ(bools.GetSparseElementAsString(1), "false");

  Literal int64s = LiteralUtil::CreateSparse<int64>(dims, indices, {1, 2, 3});
  EXPECT_EQ(int64s.GetSparseElementAsString(1), absl::StrCat(int64{2}));

  Literal doubles =
      LiteralUtil::CreateSparse<double>(dims, indices, {1.0, 2.0, 3.0});
  EXPECT_EQ(doubles.GetSparseElementAsString(1), absl::StrCat(double{2.0}));

  // F16 values are expected to print via their float conversion.
  Literal halves = LiteralUtil::CreateSparse<half>(
      dims, indices, {half{1.0}, half{2.0}, half{3.0}});
  EXPECT_EQ(halves.GetSparseElementAsString(1),
            absl::StrCat(static_cast<float>(half{2.0})));

  Literal complexes = LiteralUtil::CreateSparse<complex64>(
      dims, indices,
      std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  EXPECT_EQ(complexes.GetSparseElementAsString(1),
            absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
1994
// Broadcasting s64[2] into s64[2,2] along dimension 0 fills row i with
// element i of the source.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
  Literal source = LiteralUtil::CreateR1<int64>({1, 2});
  const Shape target_shape = ShapeUtil::MakeShape(S64, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      source.Broadcast(/*result_shape=*/target_shape, /*dimensions=*/{0}));

  const Literal expected = LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}});
  EXPECT_EQ(result, expected);
}
2004
// Broadcasting s64[2] into s64[2,2] along dimension 1 copies the whole source
// vector into every row.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
  Literal source = LiteralUtil::CreateR1<int64>({1, 2});
  const Shape target_shape = ShapeUtil::MakeShape(S64, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      source.Broadcast(/*result_shape=*/target_shape, /*dimensions=*/{1}));

  const Literal expected = LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}});
  EXPECT_EQ(result, expected);
}
2014
// Broadcasting an s32 scalar (no mapped dimensions) into s32[2,2] fills every
// element with the scalar.
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
  Literal source = LiteralUtil::CreateR0<int32>(9);
  const Shape target_shape = ShapeUtil::MakeShape(S32, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      source.Broadcast(/*result_shape=*/target_shape, /*dimensions=*/{}));

  const Literal expected = LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}});
  EXPECT_EQ(result, expected);
}
2024
2025 } // namespace
2026 } // namespace xla
2027