1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal_util.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/index_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/hash/hash.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mem.h"
37 #include "tensorflow/core/platform/types.h"
38
39 namespace xla {
40 namespace {
41
42 using absl::StrCat;
43
44 // Return a literal with all arrays of type FromNativeT converted to type
45 // ToNativeT in the given literal.
46 template <typename FromNativeT, typename ToNativeT>
ConvertType(LiteralSlice literal)47 Literal ConvertType(LiteralSlice literal) {
48 // First construct shape of the result.
49 Shape result_shape(literal.shape());
50 ShapeUtil::ForEachMutableSubshape(
51 &result_shape, [](Shape* subshape, const ShapeIndex&) {
52 if (subshape->element_type() ==
53 primitive_util::NativeToPrimitiveType<FromNativeT>()) {
54 subshape->set_element_type(
55 primitive_util::NativeToPrimitiveType<ToNativeT>());
56 }
57 });
58 Literal result(result_shape);
59
60 // Then copy over the data from 'literal' converting FromNativeT values to
61 // ToNativeT values as necessary.
62 ShapeUtil::ForEachSubshape(
63 literal.shape(),
64 [&](const Shape& subshape, const ShapeIndex& shape_index) {
65 if (subshape.IsArray()) {
66 if (subshape.element_type() ==
67 primitive_util::NativeToPrimitiveType<FromNativeT>()) {
68 auto src = literal.data<FromNativeT>(shape_index);
69 auto dest = result.data<ToNativeT>(shape_index);
70 for (int64 i = 0; i < src.size(); ++i) {
71 dest[i] = static_cast<ToNativeT>(src[i]);
72 }
73 } else {
74 TF_CHECK_OK(result.CopyFrom(literal,
75 /*dest_shape_index=*/shape_index,
76 /*src_shape_index=*/shape_index));
77 }
78 }
79 });
80 return result;
81 }
82
83 } // namespace
84
CreateFromDimensions(PrimitiveType primitive_type,absl::Span<const int64> dimensions)85 /* static */ Literal LiteralUtil::CreateFromDimensions(
86 PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
87 return Literal::CreateFromShape(
88 ShapeUtil::MakeShape(primitive_type, dimensions));
89 }
90
ConvertBF16ToF32(const LiteralSlice & bf16_literal)91 /* static */ Literal LiteralUtil::ConvertBF16ToF32(
92 const LiteralSlice& bf16_literal) {
93 return ConvertType<bfloat16, float>(bf16_literal);
94 }
95
ConvertF32ToBF16(const LiteralSlice & f32_literal)96 /* static */ Literal LiteralUtil::ConvertF32ToBF16(
97 const LiteralSlice& f32_literal) {
98 return ConvertType<float, bfloat16>(f32_literal);
99 }
100
CreateToken()101 /* static */ Literal LiteralUtil::CreateToken() {
102 return Literal(ShapeUtil::MakeTokenShape());
103 }
104
Zero(PrimitiveType primitive_type)105 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
106 switch (primitive_type) {
107 case U8:
108 return LiteralUtil::CreateR0<uint8>(0);
109 case U16:
110 return LiteralUtil::CreateR0<uint16>(0);
111 case U32:
112 return LiteralUtil::CreateR0<uint32>(0);
113 case U64:
114 return LiteralUtil::CreateR0<uint64>(0);
115 case S8:
116 return LiteralUtil::CreateR0<int8>(0);
117 case S16:
118 return LiteralUtil::CreateR0<int16>(0);
119 case S32:
120 return LiteralUtil::CreateR0<int32>(0);
121 case S64:
122 return LiteralUtil::CreateR0<int64>(0);
123 case F16:
124 return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
125 case BF16:
126 return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
127 case F32:
128 return LiteralUtil::CreateR0<float>(0);
129 case F64:
130 return LiteralUtil::CreateR0<double>(0);
131 case C64:
132 return LiteralUtil::CreateR0<complex64>(0);
133 case C128:
134 return LiteralUtil::CreateR0<complex128>(0);
135 case PRED:
136 return LiteralUtil::CreateR0<bool>(false);
137 case TUPLE:
138 LOG(FATAL) << "tuple element type cannot take on value of 0";
139 case OPAQUE:
140 LOG(FATAL) << "opaque element type cannot take on value of 0";
141 default:
142 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
143 }
144 }
145
One(PrimitiveType primitive_type)146 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
147 switch (primitive_type) {
148 case U8:
149 return LiteralUtil::CreateR0<uint8>(1);
150 case U32:
151 return LiteralUtil::CreateR0<uint32>(1);
152 case U64:
153 return LiteralUtil::CreateR0<uint64>(1);
154 case S8:
155 return LiteralUtil::CreateR0<int8>(1);
156 case S32:
157 return LiteralUtil::CreateR0<int32>(1);
158 case S64:
159 return LiteralUtil::CreateR0<int64>(1);
160 case F16:
161 return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
162 case BF16:
163 return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
164 case F32:
165 return LiteralUtil::CreateR0<float>(1);
166 case F64:
167 return LiteralUtil::CreateR0<double>(1);
168 case C64:
169 return LiteralUtil::CreateR0<complex64>(1);
170 case C128:
171 return LiteralUtil::CreateR0<complex128>(1);
172 case PRED:
173 return LiteralUtil::CreateR0<bool>(true);
174 case S16:
175 case U16:
176 LOG(FATAL) << "u16/s16 literals not yet implemented";
177 case TUPLE:
178 LOG(FATAL) << "tuple element type cannot take on value of 1";
179 case OPAQUE:
180 LOG(FATAL) << "opaque element type cannot take on value of 1";
181 default:
182 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
183 }
184 }
185
MinValue(PrimitiveType primitive_type)186 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
187 switch (primitive_type) {
188 case U8:
189 return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
190 case U32:
191 return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
192 case U64:
193 return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
194 case S8:
195 return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
196 case S32:
197 return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
198 case S64:
199 return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
200 case F32:
201 return LiteralUtil::CreateR0<float>(
202 -std::numeric_limits<float>::infinity());
203 case F64:
204 return LiteralUtil::CreateR0<double>(
205 -std::numeric_limits<double>::infinity());
206 case C64:
207 LOG(FATAL) << "C64 element type has no minimum value";
208 case C128:
209 LOG(FATAL) << "C128 element type has no minimum value";
210 case PRED:
211 return LiteralUtil::CreateR0<bool>(false);
212 case S16:
213 case U16:
214 LOG(FATAL) << "u16/s16 literals not yet implemented";
215 case F16:
216 return LiteralUtil::CreateR0<half>(
217 static_cast<half>(-std::numeric_limits<float>::infinity()));
218 case BF16:
219 return LiteralUtil::CreateR0<bfloat16>(
220 static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
221 case TUPLE:
222 LOG(FATAL) << "tuple element type has no minimum value";
223 case OPAQUE:
224 LOG(FATAL) << "opaque element type has no minimum value";
225 default:
226 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
227 }
228 }
229
MaxValue(PrimitiveType primitive_type)230 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
231 switch (primitive_type) {
232 case U8:
233 return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
234 case U32:
235 return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
236 case U64:
237 return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
238 case S8:
239 return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
240 case S32:
241 return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
242 case S64:
243 return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
244 case F32:
245 return LiteralUtil::CreateR0<float>(
246 std::numeric_limits<float>::infinity());
247 case F64:
248 return LiteralUtil::CreateR0<double>(
249 std::numeric_limits<double>::infinity());
250 case PRED:
251 return LiteralUtil::CreateR0<bool>(true);
252 case S16:
253 case U16:
254 LOG(FATAL) << "u16/s16 literals not yet implemented";
255 case F16:
256 return LiteralUtil::CreateR0<half>(
257 static_cast<half>(std::numeric_limits<float>::infinity()));
258 case BF16:
259 return LiteralUtil::CreateR0<bfloat16>(
260 static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
261 case TUPLE:
262 LOG(FATAL) << "tuple element type has no maximum value";
263 case OPAQUE:
264 LOG(FATAL) << "opaque element type has no maximum value";
265 default:
266 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
267 }
268 }
269
CreateR1(const tensorflow::core::Bitmap & values)270 /* static */ Literal LiteralUtil::CreateR1(
271 const tensorflow::core::Bitmap& values) {
272 Literal literal(
273 ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
274 literal.PopulateR1(values);
275 return literal;
276 }
277
CreateR1U8(absl::string_view value)278 /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
279 Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
280 for (int i = 0; i < value.size(); ++i) {
281 literal.Set<uint8>({i}, value[i]);
282 }
283 return literal;
284 }
285
CreateR2F32Linspace(float from,float to,int64 rows,int64 cols)286 /* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
287 int64 rows, int64 cols) {
288 auto value = MakeLinspaceArray2D(from, to, rows, cols);
289 return CreateR2FromArray2D(*value);
290 }
291
ReshapeSlice(absl::Span<const int64> new_dimensions,absl::Span<const int64> minor_to_major,const LiteralSlice & literal)292 /* static */ Literal LiteralUtil::ReshapeSlice(
293 absl::Span<const int64> new_dimensions,
294 absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
295 int64 new_num_elements = 1;
296 for (int64 i = 0; i < new_dimensions.size(); ++i) {
297 new_num_elements *= new_dimensions[i];
298 }
299 CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
300 CHECK_EQ(new_dimensions.size(), minor_to_major.size());
301
302 Literal new_literal(
303 ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
304
305 // Create a new shape with the given minor-to-major layout. This shape is used
306 // solely for converting linear address to multi-dimensional addresses when
307 // writing elements to the new literal.
308 Shape shape_with_layout = new_literal.shape();
309 *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
310
311 // Copy data into new literal, element-by-element.
312 for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
313 std::vector<int64> from_multi_index =
314 IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
315 std::vector<int64> to_multi_index =
316 IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
317 switch (literal.shape().element_type()) {
318 case PRED:
319 new_literal.Set<bool>(to_multi_index,
320 literal.Get<bool>(from_multi_index));
321 break;
322 case U8:
323 new_literal.Set<uint8>(to_multi_index,
324 literal.Get<uint8>(from_multi_index));
325 break;
326 case U32:
327 new_literal.Set<uint32>(to_multi_index,
328 literal.Get<uint32>(from_multi_index));
329 break;
330 case S32:
331 new_literal.Set<int32>(to_multi_index,
332 literal.Get<int32>(from_multi_index));
333 break;
334 case U64:
335 new_literal.Set<uint64>(to_multi_index,
336 literal.Get<uint64>(from_multi_index));
337 break;
338 case S64:
339 new_literal.Set<int64>(to_multi_index,
340 literal.Get<int64>(from_multi_index));
341 break;
342 case F32:
343 new_literal.Set<float>(to_multi_index,
344 literal.Get<float>(from_multi_index));
345 break;
346 case F64:
347 new_literal.Set<double>(to_multi_index,
348 literal.Get<double>(from_multi_index));
349 break;
350 case C64:
351 new_literal.Set<complex64>(to_multi_index,
352 literal.Get<complex64>(from_multi_index));
353 break;
354 case C128:
355 new_literal.Set<complex128>(to_multi_index,
356 literal.Get<complex128>(from_multi_index));
357 break;
358 default:
359 LOG(FATAL) << "Unhandled primitive element type: "
360 << PrimitiveType_Name(literal.shape().element_type());
361 }
362 }
363
364 return new_literal;
365 }
366
GetFirstScalarLiteral(const LiteralSlice & literal)367 /* static */ Literal LiteralUtil::GetFirstScalarLiteral(
368 const LiteralSlice& literal) {
369 CHECK(literal.shape().IsArray());
370 CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
371 switch (literal.shape().element_type()) {
372 case PRED:
373 return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
374 // 8 bit types.
375 case S8:
376 return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
377 case U8:
378 return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
379 // 16 bit types.
380 case BF16:
381 return LiteralUtil::CreateR0<bfloat16>(
382 literal.GetFirstElement<bfloat16>());
383 case F16:
384 return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
385 case S16:
386 return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
387 case U16:
388 return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
389 // 32 bit types.
390 case F32:
391 return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
392 case S32:
393 return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
394 case U32:
395 return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
396 // 64 bit types.
397 case C64:
398 return LiteralUtil::CreateR0<complex64>(
399 literal.GetFirstElement<complex64>());
400 case F64:
401 return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
402 case S64:
403 return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
404 case U64:
405 return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
406
407 case C128:
408 return LiteralUtil::CreateR0<complex128>(
409 literal.GetFirstElement<complex128>());
410 default:
411 LOG(FATAL) << "Unhandled primitive type "
412 << literal.shape().element_type();
413 }
414 }
415
MakeTuple(absl::Span<const Literal * const> elements)416 /* static */ Literal LiteralUtil::MakeTuple(
417 absl::Span<const Literal* const> elements) {
418 std::vector<Shape> element_shapes;
419 for (const auto* element : elements) {
420 element_shapes.push_back(element->shape());
421 }
422 Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
423 for (int i = 0; i < elements.size(); ++i) {
424 TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
425 }
426 return literal;
427 }
428
MakeTupleFromSlices(absl::Span<const LiteralSlice> elements)429 /* static */ Literal LiteralUtil::MakeTupleFromSlices(
430 absl::Span<const LiteralSlice> elements) {
431 std::vector<Shape> element_shapes;
432 for (const auto& element : elements) {
433 element_shapes.push_back(element.shape());
434 }
435 Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
436 for (int i = 0; i < elements.size(); ++i) {
437 TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
438 }
439 return literal;
440 }
441
MakeTupleOwned(std::vector<Literal> elements)442 /* static */ Literal LiteralUtil::MakeTupleOwned(
443 std::vector<Literal> elements) {
444 std::vector<Shape> element_shapes;
445 element_shapes.reserve(elements.size());
446 for (const auto& element : elements) {
447 element_shapes.push_back(element.shape());
448 }
449 Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
450 for (int64 i = 0; i < elements.size(); ++i) {
451 TF_CHECK_OK(
452 literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
453 }
454 return literal;
455 }
456
MultiIndexAsString(absl::Span<const int64> multi_index)457 /* static */ string LiteralUtil::MultiIndexAsString(
458 absl::Span<const int64> multi_index) {
459 return StrCat("{", absl::StrJoin(multi_index, ","), "}");
460 }
461
462 } // namespace xla
463