1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal_util.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/index_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/hash/hash.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mem.h"
37 #include "tensorflow/core/platform/types.h"
38
39 namespace xla {
40 namespace {
41
42 using absl::StrCat;
43
44 // Return a literal with all arrays of type FromNativeT converted to type
45 // ToNativeT in the given literal.
46 template <typename FromNativeT, typename ToNativeT>
ConvertType(LiteralSlice literal)47 Literal ConvertType(LiteralSlice literal) {
48 // First construct shape of the result.
49 Shape result_shape(literal.shape());
50 ShapeUtil::ForEachMutableSubshape(
51 &result_shape, [](Shape* subshape, const ShapeIndex&) {
52 if (subshape->element_type() ==
53 primitive_util::NativeToPrimitiveType<FromNativeT>()) {
54 subshape->set_element_type(
55 primitive_util::NativeToPrimitiveType<ToNativeT>());
56 }
57 });
58 Literal result(result_shape);
59
60 // Then copy over the data from 'literal' converting FromNativeT values to
61 // ToNativeT values as necessary.
62 ShapeUtil::ForEachSubshape(
63 literal.shape(),
64 [&](const Shape& subshape, const ShapeIndex& shape_index) {
65 if (subshape.IsArray()) {
66 if (subshape.element_type() ==
67 primitive_util::NativeToPrimitiveType<FromNativeT>()) {
68 auto src = literal.data<FromNativeT>(shape_index);
69 auto dest = result.data<ToNativeT>(shape_index);
70 for (int64 i = 0; i < src.size(); ++i) {
71 dest[i] = static_cast<ToNativeT>(src[i]);
72 }
73 } else {
74 TF_CHECK_OK(result.CopyFrom(literal,
75 /*dest_shape_index=*/shape_index,
76 /*src_shape_index=*/shape_index));
77 }
78 }
79 });
80 return result;
81 }
82
83 } // namespace
84
CreateFromDimensions(PrimitiveType primitive_type,absl::Span<const int64> dimensions)85 /* static */ Literal LiteralUtil::CreateFromDimensions(
86 PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
87 return Literal::CreateFromShape(
88 ShapeUtil::MakeShape(primitive_type, dimensions));
89 }
90
ConvertBF16ToF32(const LiteralSlice & bf16_literal)91 /* static */ Literal LiteralUtil::ConvertBF16ToF32(
92 const LiteralSlice& bf16_literal) {
93 return ConvertType<bfloat16, float>(bf16_literal);
94 }
95
ConvertBF16ToF64(const LiteralSlice & bf16_literal)96 /* static */ Literal LiteralUtil::ConvertBF16ToF64(
97 const LiteralSlice& bf16_literal) {
98 return ConvertType<bfloat16, double>(bf16_literal);
99 }
100
ConvertF32ToBF16(const LiteralSlice & f32_literal)101 /* static */ Literal LiteralUtil::ConvertF32ToBF16(
102 const LiteralSlice& f32_literal) {
103 return ConvertType<float, bfloat16>(f32_literal);
104 }
105
ConvertF32ToF64(const LiteralSlice & f32_literal)106 /* static */ Literal LiteralUtil::ConvertF32ToF64(
107 const LiteralSlice& f32_literal) {
108 return ConvertType<float, double>(f32_literal);
109 }
110
ConvertF64ToBF16(const LiteralSlice & f64_literal)111 /* static */ Literal LiteralUtil::ConvertF64ToBF16(
112 const LiteralSlice& f64_literal) {
113 return ConvertType<double, bfloat16>(f64_literal);
114 }
115
ConvertF64ToF32(const LiteralSlice & f64_literal)116 /* static */ Literal LiteralUtil::ConvertF64ToF32(
117 const LiteralSlice& f64_literal) {
118 return ConvertType<double, float>(f64_literal);
119 }
120
CreateToken()121 /* static */ Literal LiteralUtil::CreateToken() {
122 return Literal(ShapeUtil::MakeTokenShape());
123 }
124
Zero(PrimitiveType primitive_type)125 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
126 switch (primitive_type) {
127 case U8:
128 return LiteralUtil::CreateR0<uint8>(0);
129 case U16:
130 return LiteralUtil::CreateR0<uint16>(0);
131 case U32:
132 return LiteralUtil::CreateR0<uint32>(0);
133 case U64:
134 return LiteralUtil::CreateR0<uint64>(0);
135 case S8:
136 return LiteralUtil::CreateR0<int8>(0);
137 case S16:
138 return LiteralUtil::CreateR0<int16>(0);
139 case S32:
140 return LiteralUtil::CreateR0<int32>(0);
141 case S64:
142 return LiteralUtil::CreateR0<int64>(0);
143 case F16:
144 return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
145 case BF16:
146 return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
147 case F32:
148 return LiteralUtil::CreateR0<float>(0);
149 case F64:
150 return LiteralUtil::CreateR0<double>(0);
151 case C64:
152 return LiteralUtil::CreateR0<complex64>(0);
153 case C128:
154 return LiteralUtil::CreateR0<complex128>(0);
155 case PRED:
156 return LiteralUtil::CreateR0<bool>(false);
157 case TUPLE:
158 LOG(FATAL) << "tuple element type cannot take on value of 0";
159 case OPAQUE_TYPE:
160 LOG(FATAL) << "opaque element type cannot take on value of 0";
161 default:
162 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
163 }
164 }
165
One(PrimitiveType primitive_type)166 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
167 switch (primitive_type) {
168 case U8:
169 return LiteralUtil::CreateR0<uint8>(1);
170 case U16:
171 return LiteralUtil::CreateR0<uint16>(1);
172 case U32:
173 return LiteralUtil::CreateR0<uint32>(1);
174 case U64:
175 return LiteralUtil::CreateR0<uint64>(1);
176 case S8:
177 return LiteralUtil::CreateR0<int8>(1);
178 case S16:
179 return LiteralUtil::CreateR0<int16>(1);
180 case S32:
181 return LiteralUtil::CreateR0<int32>(1);
182 case S64:
183 return LiteralUtil::CreateR0<int64>(1);
184 case F16:
185 return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
186 case BF16:
187 return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
188 case F32:
189 return LiteralUtil::CreateR0<float>(1);
190 case F64:
191 return LiteralUtil::CreateR0<double>(1);
192 case C64:
193 return LiteralUtil::CreateR0<complex64>(1);
194 case C128:
195 return LiteralUtil::CreateR0<complex128>(1);
196 case PRED:
197 return LiteralUtil::CreateR0<bool>(true);
198 case TUPLE:
199 LOG(FATAL) << "tuple element type cannot take on value of 1";
200 case OPAQUE_TYPE:
201 LOG(FATAL) << "opaque element type cannot take on value of 1";
202 default:
203 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
204 }
205 }
206
MinValue(PrimitiveType primitive_type)207 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
208 switch (primitive_type) {
209 case U8:
210 return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
211 case U16:
212 return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::min());
213 case U32:
214 return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
215 case U64:
216 return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
217 case S8:
218 return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
219 case S16:
220 return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::min());
221 case S32:
222 return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
223 case S64:
224 return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
225 case F32:
226 return LiteralUtil::CreateR0<float>(
227 -std::numeric_limits<float>::infinity());
228 case F64:
229 return LiteralUtil::CreateR0<double>(
230 -std::numeric_limits<double>::infinity());
231 case C64:
232 LOG(FATAL) << "C64 element type has no minimum value";
233 case C128:
234 LOG(FATAL) << "C128 element type has no minimum value";
235 case PRED:
236 return LiteralUtil::CreateR0<bool>(false);
237 case F16:
238 return LiteralUtil::CreateR0<half>(
239 static_cast<half>(-std::numeric_limits<float>::infinity()));
240 case BF16:
241 return LiteralUtil::CreateR0<bfloat16>(
242 static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
243 case TUPLE:
244 LOG(FATAL) << "tuple element type has no minimum value";
245 case OPAQUE_TYPE:
246 LOG(FATAL) << "opaque element type has no minimum value";
247 default:
248 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
249 }
250 }
251
MaxValue(PrimitiveType primitive_type)252 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
253 switch (primitive_type) {
254 case U8:
255 return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
256 case U16:
257 return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::max());
258 case U32:
259 return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
260 case U64:
261 return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
262 case S8:
263 return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
264 case S16:
265 return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::max());
266 case S32:
267 return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
268 case S64:
269 return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
270 case F32:
271 return LiteralUtil::CreateR0<float>(
272 std::numeric_limits<float>::infinity());
273 case F64:
274 return LiteralUtil::CreateR0<double>(
275 std::numeric_limits<double>::infinity());
276 case PRED:
277 return LiteralUtil::CreateR0<bool>(true);
278 case F16:
279 return LiteralUtil::CreateR0<half>(
280 static_cast<half>(std::numeric_limits<float>::infinity()));
281 case BF16:
282 return LiteralUtil::CreateR0<bfloat16>(
283 static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
284 case TUPLE:
285 LOG(FATAL) << "tuple element type has no maximum value";
286 case OPAQUE_TYPE:
287 LOG(FATAL) << "opaque element type has no maximum value";
288 default:
289 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
290 }
291 }
292
NanValue(PrimitiveType primitive_type)293 /* static */ StatusOr<Literal> LiteralUtil::NanValue(
294 PrimitiveType primitive_type) {
295 switch (primitive_type) {
296 case F16:
297 return LiteralUtil::CreateR0<half>(
298 static_cast<half>(std::numeric_limits<float>::quiet_NaN()));
299 case BF16:
300 return LiteralUtil::CreateR0<bfloat16>(
301 static_cast<bfloat16>(std::numeric_limits<float>::quiet_NaN()));
302 case F32:
303 return LiteralUtil::CreateR0<float>(
304 std::numeric_limits<float>::quiet_NaN());
305 case F64:
306 return LiteralUtil::CreateR0<double>(
307 std::numeric_limits<double>::quiet_NaN());
308 case C64: {
309 float nan = std::numeric_limits<float>::quiet_NaN();
310 return LiteralUtil::CreateR0<complex64>(complex64(nan, nan));
311 }
312 case C128: {
313 double nan = std::numeric_limits<double>::quiet_NaN();
314 return LiteralUtil::CreateR0<complex128>(complex128(nan, nan));
315 }
316 default:
317 return InvalidArgument("Invalid type for NanValue: %s",
318 PrimitiveType_Name(primitive_type));
319 }
320 }
321
CreateR1(const tensorflow::core::Bitmap & values)322 /* static */ Literal LiteralUtil::CreateR1(
323 const tensorflow::core::Bitmap& values) {
324 Literal literal(
325 ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
326 literal.PopulateR1(values);
327 return literal;
328 }
329
CreateR1U8(absl::string_view value)330 /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
331 Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
332 for (int i = 0; i < value.size(); ++i) {
333 literal.Set<uint8>({i}, value[i]);
334 }
335 return literal;
336 }
337
CreateR2F32Linspace(float from,float to,int64 rows,int64 cols)338 /* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
339 int64 rows, int64 cols) {
340 auto value = MakeLinspaceArray2D(from, to, rows, cols);
341 return CreateR2FromArray2D(*value);
342 }
343
ReshapeSlice(absl::Span<const int64> new_dimensions,absl::Span<const int64> minor_to_major,const LiteralSlice & literal)344 /* static */ Literal LiteralUtil::ReshapeSlice(
345 absl::Span<const int64> new_dimensions,
346 absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
347 int64 new_num_elements = 1;
348 for (int64 i = 0; i < new_dimensions.size(); ++i) {
349 new_num_elements *= new_dimensions[i];
350 }
351 CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
352 CHECK_EQ(new_dimensions.size(), minor_to_major.size());
353
354 Literal new_literal(
355 ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
356
357 // Create a new shape with the given minor-to-major layout. This shape is used
358 // solely for converting linear address to multi-dimensional addresses when
359 // writing elements to the new literal.
360 Shape shape_with_layout = new_literal.shape();
361 *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
362
363 // Copy data into new literal, element-by-element.
364 for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
365 std::vector<int64> from_multi_index =
366 IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
367 std::vector<int64> to_multi_index =
368 IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
369 switch (literal.shape().element_type()) {
370 case PRED:
371 new_literal.Set<bool>(to_multi_index,
372 literal.Get<bool>(from_multi_index));
373 break;
374 case U8:
375 new_literal.Set<uint8>(to_multi_index,
376 literal.Get<uint8>(from_multi_index));
377 break;
378 case U32:
379 new_literal.Set<uint32>(to_multi_index,
380 literal.Get<uint32>(from_multi_index));
381 break;
382 case S32:
383 new_literal.Set<int32>(to_multi_index,
384 literal.Get<int32>(from_multi_index));
385 break;
386 case U64:
387 new_literal.Set<uint64>(to_multi_index,
388 literal.Get<uint64>(from_multi_index));
389 break;
390 case S64:
391 new_literal.Set<int64>(to_multi_index,
392 literal.Get<int64>(from_multi_index));
393 break;
394 case F32:
395 new_literal.Set<float>(to_multi_index,
396 literal.Get<float>(from_multi_index));
397 break;
398 case F64:
399 new_literal.Set<double>(to_multi_index,
400 literal.Get<double>(from_multi_index));
401 break;
402 case C64:
403 new_literal.Set<complex64>(to_multi_index,
404 literal.Get<complex64>(from_multi_index));
405 break;
406 case C128:
407 new_literal.Set<complex128>(to_multi_index,
408 literal.Get<complex128>(from_multi_index));
409 break;
410 default:
411 LOG(FATAL) << "Unhandled primitive element type: "
412 << PrimitiveType_Name(literal.shape().element_type());
413 }
414 }
415
416 return new_literal;
417 }
418
GetFirstScalarLiteral(const LiteralSlice & literal)419 /* static */ Literal LiteralUtil::GetFirstScalarLiteral(
420 const LiteralSlice& literal) {
421 CHECK(literal.shape().IsArray());
422 CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
423 switch (literal.shape().element_type()) {
424 case PRED:
425 return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
426 // 8 bit types.
427 case S8:
428 return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
429 case U8:
430 return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
431 // 16 bit types.
432 case BF16:
433 return LiteralUtil::CreateR0<bfloat16>(
434 literal.GetFirstElement<bfloat16>());
435 case F16:
436 return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
437 case S16:
438 return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
439 case U16:
440 return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
441 // 32 bit types.
442 case F32:
443 return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
444 case S32:
445 return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
446 case U32:
447 return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
448 // 64 bit types.
449 case C64:
450 return LiteralUtil::CreateR0<complex64>(
451 literal.GetFirstElement<complex64>());
452 case F64:
453 return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
454 case S64:
455 return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
456 case U64:
457 return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
458
459 case C128:
460 return LiteralUtil::CreateR0<complex128>(
461 literal.GetFirstElement<complex128>());
462 default:
463 LOG(FATAL) << "Unhandled primitive type "
464 << literal.shape().element_type();
465 }
466 }
467
MakeTuple(absl::Span<const Literal * const> elements)468 /* static */ Literal LiteralUtil::MakeTuple(
469 absl::Span<const Literal* const> elements) {
470 std::vector<Shape> element_shapes;
471 for (const auto* element : elements) {
472 element_shapes.push_back(element->shape());
473 }
474 Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
475 for (int i = 0; i < elements.size(); ++i) {
476 TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
477 }
478 return literal;
479 }
480
MakeTupleFromSlices(absl::Span<const LiteralSlice> elements)481 /* static */ Literal LiteralUtil::MakeTupleFromSlices(
482 absl::Span<const LiteralSlice> elements) {
483 std::vector<Shape> element_shapes;
484 for (const auto& element : elements) {
485 element_shapes.push_back(element.shape());
486 }
487 Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
488 for (int i = 0; i < elements.size(); ++i) {
489 TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
490 }
491 return literal;
492 }
493
MakeTupleOwned(std::vector<Literal> elements)494 /* static */ Literal LiteralUtil::MakeTupleOwned(
495 std::vector<Literal> elements) {
496 std::vector<Shape> element_shapes;
497 element_shapes.reserve(elements.size());
498 for (const auto& element : elements) {
499 element_shapes.push_back(element.shape());
500 }
501 Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
502 for (int64 i = 0; i < elements.size(); ++i) {
503 TF_CHECK_OK(
504 literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
505 }
506 return literal;
507 }
508
MultiIndexAsString(absl::Span<const int64> multi_index)509 /* static */ string LiteralUtil::MultiIndexAsString(
510 absl::Span<const int64> multi_index) {
511 return StrCat("{", absl::StrJoin(multi_index, ","), "}");
512 }
513
514 } // namespace xla
515