1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal_util.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "tensorflow/compiler/xla/index_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 namespace xla {
39 namespace {
40
41 using absl::StrCat;
42
43 // Return a literal with all arrays of type FromNativeT converted to type
44 // ToNativeT in the given literal.
45 template <typename FromNativeT, typename ToNativeT>
ConvertType(LiteralSlice literal)46 Literal ConvertType(LiteralSlice literal) {
47 // First construct shape of the result.
48 Shape result_shape(literal.shape());
49 ShapeUtil::ForEachMutableSubshape(
50 &result_shape, [](Shape* subshape, const ShapeIndex&) {
51 if (subshape->element_type() ==
52 primitive_util::NativeToPrimitiveType<FromNativeT>()) {
53 subshape->set_element_type(
54 primitive_util::NativeToPrimitiveType<ToNativeT>());
55 }
56 });
57 Literal result(result_shape);
58
59 // Then copy over the data from 'literal' converting FromNativeT values to
60 // ToNativeT values as necessary.
61 ShapeUtil::ForEachSubshape(
62 literal.shape(),
63 [&](const Shape& subshape, const ShapeIndex& shape_index) {
64 if (subshape.IsArray()) {
65 if (subshape.element_type() ==
66 primitive_util::NativeToPrimitiveType<FromNativeT>()) {
67 auto src = literal.data<FromNativeT>(shape_index);
68 auto dest = result.data<ToNativeT>(shape_index);
69 for (int64_t i = 0, end = src.size(); i < end; ++i) {
70 dest[i] = static_cast<ToNativeT>(src[i]);
71 }
72 } else {
73 TF_CHECK_OK(result.CopyFrom(literal,
74 /*dest_shape_index=*/shape_index,
75 /*src_shape_index=*/shape_index));
76 }
77 }
78 });
79 return result;
80 }
81
82 template <PrimitiveType kType>
83 using NativeT = typename primitive_util::PrimitiveTypeToNative<kType>::type;
84
85 template <PrimitiveType kType, typename F, typename... Args>
CreateScalarImpl(F && value_provider,Args...args)86 Literal CreateScalarImpl(F&& value_provider, Args... args) {
87 return LiteralUtil::CreateR0<NativeT<kType>>(
88 value_provider(std::forward<Args>(args)...));
89 }
90
91 template <template <PrimitiveType> class F, typename... Args>
CreateScalar(PrimitiveType primitive_type,Args...args)92 Literal CreateScalar(PrimitiveType primitive_type, Args... args) {
93 switch (primitive_type) {
94 case U8:
95 return CreateScalarImpl<U8>(F<U8>{}, std::forward<Args>(args)...);
96 case U16:
97 return CreateScalarImpl<U16>(F<U16>{}, std::forward<Args>(args)...);
98 case U32:
99 return CreateScalarImpl<U32>(F<U32>{}, std::forward<Args>(args)...);
100 case U64:
101 return CreateScalarImpl<U64>(F<U64>{}, std::forward<Args>(args)...);
102 case S8:
103 return CreateScalarImpl<S8>(F<S8>{}, std::forward<Args>(args)...);
104 case S16:
105 return CreateScalarImpl<S16>(F<S16>{}, std::forward<Args>(args)...);
106 case S32:
107 return CreateScalarImpl<S32>(F<S32>{}, std::forward<Args>(args)...);
108 case S64:
109 return CreateScalarImpl<S64>(F<S64>{}, std::forward<Args>(args)...);
110 case F16:
111 return CreateScalarImpl<F16>(F<F16>{}, std::forward<Args>(args)...);
112 case BF16:
113 return CreateScalarImpl<BF16>(F<BF16>{}, std::forward<Args>(args)...);
114 case F32:
115 return CreateScalarImpl<F32>(F<F32>{}, std::forward<Args>(args)...);
116 case F64:
117 return CreateScalarImpl<F64>(F<F64>{}, std::forward<Args>(args)...);
118 case C64:
119 return CreateScalarImpl<C64>(F<C64>{}, std::forward<Args>(args)...);
120 case C128:
121 return CreateScalarImpl<C128>(F<C128>{}, std::forward<Args>(args)...);
122 case PRED:
123 return CreateScalarImpl<PRED>(F<PRED>{}, std::forward<Args>(args)...);
124 case TUPLE:
125 LOG(FATAL) << "tuple element type cannot be a scalar type.";
126 case OPAQUE_TYPE:
127 LOG(FATAL) << "opaque element type cannot be a scalar type.";
128 default:
129 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
130 }
131 }
132
133 template <PrimitiveType kType>
134 struct ZeroProvider {
operator ()xla::__anond64a65160111::ZeroProvider135 NativeT<kType> operator()() const { return static_cast<NativeT<kType>>(0); }
136 };
137
138 template <PrimitiveType kType>
139 struct OneProvider {
operator ()xla::__anond64a65160111::OneProvider140 NativeT<kType> operator()() const { return static_cast<NativeT<kType>>(1); }
141 };
142
143 template <typename T>
144 struct Is16BitFloat {
145 static constexpr bool value =
146 std::is_same<bfloat16, T>::value || std::is_same<half, T>::value;
147 };
148
149 template <typename T>
150 struct IsReal {
151 static constexpr bool value =
152 std::is_integral<T>::value || std::is_floating_point<T>::value ||
153 std::is_same<bfloat16, T>::value || std::is_same<half, T>::value;
154 };
155
156 template <typename T>
157 struct IsValidScalarType {
158 static constexpr bool value = IsReal<T>::value ||
159 std::is_same<complex64, T>::value ||
160 std::is_same<complex128, T>::value;
161 };
162
163 template <typename NativeT>
GetMaxImpl()164 std::enable_if_t<std::is_integral<NativeT>::value, NativeT> GetMaxImpl() {
165 return std::numeric_limits<NativeT>::max();
166 }
167
168 template <typename NativeT>
GetMinImpl()169 std::enable_if_t<std::is_integral<NativeT>::value, NativeT> GetMinImpl() {
170 return std::numeric_limits<NativeT>::min();
171 }
172
173 template <typename NativeT>
GetMaxImpl()174 std::enable_if_t<std::is_floating_point<NativeT>::value, NativeT> GetMaxImpl() {
175 return std::numeric_limits<NativeT>::infinity();
176 }
177
178 template <typename NativeT>
GetMinImpl()179 std::enable_if_t<std::is_floating_point<NativeT>::value, NativeT> GetMinImpl() {
180 return -std::numeric_limits<NativeT>::infinity();
181 }
182
183 template <typename NativeT>
GetMaxImpl()184 std::enable_if_t<Is16BitFloat<NativeT>::value, NativeT> GetMaxImpl() {
185 return static_cast<NativeT>(std::numeric_limits<float>::infinity());
186 }
187
188 template <typename NativeT>
GetMinImpl()189 std::enable_if_t<Is16BitFloat<NativeT>::value, NativeT> GetMinImpl() {
190 return static_cast<NativeT>(-std::numeric_limits<float>::infinity());
191 }
192
193 template <typename NativeT>
GetMaxImpl()194 std::enable_if_t<!IsReal<NativeT>::value, NativeT> GetMaxImpl() {
195 LOG(FATAL) << "No max value for given type.";
196 }
197
198 template <typename NativeT>
GetMinImpl()199 std::enable_if_t<!IsReal<NativeT>::value, NativeT> GetMinImpl() {
200 LOG(FATAL) << "No min value for given type.";
201 }
202
203 template <PrimitiveType kType>
204 struct MaxProvider {
operator ()xla::__anond64a65160111::MaxProvider205 NativeT<kType> operator()() const { return GetMaxImpl<NativeT<kType>>(); }
206 };
207
208 template <PrimitiveType kType>
209 struct MinProvider {
operator ()xla::__anond64a65160111::MinProvider210 NativeT<kType> operator()() const { return GetMinImpl<NativeT<kType>>(); }
211 };
212
213 template <PrimitiveType kType>
214 struct FirstElementProvider {
operator ()xla::__anond64a65160111::FirstElementProvider215 NativeT<kType> operator()(const LiteralBase& literal) const {
216 return literal.GetFirstElement<NativeT<kType>>();
217 }
218 };
219
220 template <typename NativeT>
GetMaxElementImpl(const LiteralBase & literal)221 std::enable_if_t<IsReal<NativeT>::value, NativeT> GetMaxElementImpl(
222 const LiteralBase& literal) {
223 auto view = literal.data<NativeT>();
224 return *absl::c_max_element(view);
225 }
226
227 template <typename NativeT>
GetMaxElementImpl(const LiteralBase & literal)228 std::enable_if_t<!IsReal<NativeT>::value, NativeT> GetMaxElementImpl(
229 const LiteralBase& literal) {
230 LOG(FATAL) << "Unsupported type.";
231 }
232
233 template <PrimitiveType kType>
234 struct MaxElementProvider {
operator ()xla::__anond64a65160111::MaxElementProvider235 NativeT<kType> operator()(const LiteralBase& literal) const {
236 return GetMaxElementImpl<NativeT<kType>>(literal);
237 }
238 };
239
240 template <typename NativeT>
241 std::enable_if_t<IsValidScalarType<NativeT>::value, NativeT>
GetElementAtIndexImpl(const LiteralBase * literal,absl::Span<const int64_t> multi_index)242 GetElementAtIndexImpl(const LiteralBase* literal,
243 absl::Span<const int64_t> multi_index) {
244 return literal->Get<NativeT>(multi_index);
245 }
246
247 template <typename NativeT>
248 std::enable_if_t<!IsValidScalarType<NativeT>::value, NativeT>
GetElementAtIndexImpl(const LiteralBase * literal,absl::Span<const int64_t> multi_index)249 GetElementAtIndexImpl(const LiteralBase* literal,
250 absl::Span<const int64_t> multi_index) {
251 LOG(FATAL) << "Not a valid scalar element type.";
252 }
253
254 template <PrimitiveType kType>
255 struct GetElementAtIndexProvider {
operator ()xla::__anond64a65160111::GetElementAtIndexProvider256 NativeT<kType> operator()(const LiteralBase* literal,
257 absl::Span<const int64_t> multi_index) const {
258 DCHECK_EQ(literal->shape().element_type(), kType);
259 return GetElementAtIndexImpl<NativeT<kType>>(literal, multi_index);
260 }
261 };
262
263 template <PrimitiveType kType>
SetScalarAtIndexImpl(MutableLiteralBase & literal,absl::Span<const int64_t> multi_index,const LiteralBase & scalar)264 void SetScalarAtIndexImpl(MutableLiteralBase& literal,
265 absl::Span<const int64_t> multi_index,
266 const LiteralBase& scalar) {
267 DCHECK_EQ(literal.shape().element_type(), kType);
268 using NativeT = typename primitive_util::PrimitiveTypeToNative<kType>::type;
269 literal.Set<NativeT>(multi_index, scalar.Get<NativeT>({}));
270 }
271
272 } // namespace
273
CreateFromDimensions(PrimitiveType primitive_type,absl::Span<const int64_t> dimensions)274 /* static */ Literal LiteralUtil::CreateFromDimensions(
275 PrimitiveType primitive_type, absl::Span<const int64_t> dimensions) {
276 return Literal::CreateFromShape(
277 ShapeUtil::MakeShape(primitive_type, dimensions));
278 }
279
ConvertBF16ToF32(const LiteralSlice & bf16_literal)280 /* static */ Literal LiteralUtil::ConvertBF16ToF32(
281 const LiteralSlice& bf16_literal) {
282 return ConvertType<bfloat16, float>(bf16_literal);
283 }
284
ConvertBF16ToF64(const LiteralSlice & bf16_literal)285 /* static */ Literal LiteralUtil::ConvertBF16ToF64(
286 const LiteralSlice& bf16_literal) {
287 return ConvertType<bfloat16, double>(bf16_literal);
288 }
289
ConvertF32ToBF16(const LiteralSlice & f32_literal)290 /* static */ Literal LiteralUtil::ConvertF32ToBF16(
291 const LiteralSlice& f32_literal) {
292 return ConvertType<float, bfloat16>(f32_literal);
293 }
294
ConvertF32ToF64(const LiteralSlice & f32_literal)295 /* static */ Literal LiteralUtil::ConvertF32ToF64(
296 const LiteralSlice& f32_literal) {
297 return ConvertType<float, double>(f32_literal);
298 }
299
ConvertF64ToBF16(const LiteralSlice & f64_literal)300 /* static */ Literal LiteralUtil::ConvertF64ToBF16(
301 const LiteralSlice& f64_literal) {
302 return ConvertType<double, bfloat16>(f64_literal);
303 }
304
ConvertF64ToF32(const LiteralSlice & f64_literal)305 /* static */ Literal LiteralUtil::ConvertF64ToF32(
306 const LiteralSlice& f64_literal) {
307 return ConvertType<double, float>(f64_literal);
308 }
309
ConvertS32ToF32(const LiteralSlice & s32_literal)310 /* static */ Literal LiteralUtil::ConvertS32ToF32(
311 const LiteralSlice& s32_literal) {
312 return ConvertType<int32_t, float>(s32_literal);
313 }
314
CreateToken()315 /* static */ Literal LiteralUtil::CreateToken() {
316 return Literal(ShapeUtil::MakeTokenShape());
317 }
318
Zero(PrimitiveType primitive_type)319 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
320 return CreateScalar<ZeroProvider>(primitive_type);
321 }
322
One(PrimitiveType primitive_type)323 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
324 return CreateScalar<OneProvider>(primitive_type);
325 }
326
MinValue(PrimitiveType primitive_type)327 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
328 return CreateScalar<MinProvider>(primitive_type);
329 }
330
MaxValue(PrimitiveType primitive_type)331 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
332 return CreateScalar<MaxProvider>(primitive_type);
333 }
334
NanValue(PrimitiveType primitive_type)335 /* static */ StatusOr<Literal> LiteralUtil::NanValue(
336 PrimitiveType primitive_type) {
337 switch (primitive_type) {
338 case F16:
339 return LiteralUtil::CreateR0<half>(
340 static_cast<half>(std::numeric_limits<float>::quiet_NaN()));
341 case BF16:
342 return LiteralUtil::CreateR0<bfloat16>(
343 static_cast<bfloat16>(std::numeric_limits<float>::quiet_NaN()));
344 case F32:
345 return LiteralUtil::CreateR0<float>(
346 std::numeric_limits<float>::quiet_NaN());
347 case F64:
348 return LiteralUtil::CreateR0<double>(
349 std::numeric_limits<double>::quiet_NaN());
350 case C64: {
351 float nan = std::numeric_limits<float>::quiet_NaN();
352 return LiteralUtil::CreateR0<complex64>(complex64(nan, nan));
353 }
354 case C128: {
355 double nan = std::numeric_limits<double>::quiet_NaN();
356 return LiteralUtil::CreateR0<complex128>(complex128(nan, nan));
357 }
358 default:
359 return InvalidArgument("Invalid type for NanValue: %s",
360 PrimitiveType_Name(primitive_type));
361 }
362 }
363
CreateR1(const tensorflow::core::Bitmap & values)364 /* static */ Literal LiteralUtil::CreateR1(
365 const tensorflow::core::Bitmap& values) {
366 Literal literal(
367 ShapeUtil::MakeShape(PRED, {static_cast<int64_t>(values.bits())}));
368 literal.PopulateR1(values);
369 return literal;
370 }
371
CreateR1U8(absl::string_view value)372 /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
373 Literal literal(
374 ShapeUtil::MakeShape(U8, {static_cast<int64_t>(value.size())}));
375 for (int i = 0, end = value.size(); i < end; ++i) {
376 literal.Set<uint8_t>({i}, value[i]);
377 }
378 return literal;
379 }
380
CreateR2F32Linspace(float from,float to,int64_t rows,int64_t cols)381 /* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
382 int64_t rows,
383 int64_t cols) {
384 auto value = MakeLinspaceArray2D(from, to, rows, cols);
385 return CreateR2FromArray2D(*value);
386 }
387
ReshapeSlice(absl::Span<const int64_t> new_dimensions,absl::Span<const int64_t> minor_to_major,const LiteralSlice & literal)388 /* static */ Literal LiteralUtil::ReshapeSlice(
389 absl::Span<const int64_t> new_dimensions,
390 absl::Span<const int64_t> minor_to_major, const LiteralSlice& literal) {
391 int64_t new_num_elements = 1;
392 for (int64_t i = 0, end = new_dimensions.size(); i < end; ++i) {
393 new_num_elements *= new_dimensions[i];
394 }
395 CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
396 CHECK_EQ(new_dimensions.size(), minor_to_major.size());
397
398 Literal new_literal(
399 ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
400
401 // Create a new shape with the given minor-to-major layout. This shape is used
402 // solely for converting linear address to multi-dimensional addresses when
403 // writing elements to the new literal.
404 Shape shape_with_layout = new_literal.shape();
405 *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
406
407 // Copy data into new literal, element-by-element.
408 for (int64_t i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
409 std::vector<int64_t> from_multi_index =
410 IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
411 std::vector<int64_t> to_multi_index =
412 IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
413 switch (literal.shape().element_type()) {
414 case PRED:
415 new_literal.Set<bool>(to_multi_index,
416 literal.Get<bool>(from_multi_index));
417 break;
418 case U8:
419 new_literal.Set<uint8_t>(to_multi_index,
420 literal.Get<uint8_t>(from_multi_index));
421 break;
422 case U32:
423 new_literal.Set<uint32_t>(to_multi_index,
424 literal.Get<uint32_t>(from_multi_index));
425 break;
426 case S32:
427 new_literal.Set<int32_t>(to_multi_index,
428 literal.Get<int32_t>(from_multi_index));
429 break;
430 case U64:
431 new_literal.Set<uint64_t>(to_multi_index,
432 literal.Get<uint64_t>(from_multi_index));
433 break;
434 case S64:
435 new_literal.Set<int64_t>(to_multi_index,
436 literal.Get<int64_t>(from_multi_index));
437 break;
438 case F32:
439 new_literal.Set<float>(to_multi_index,
440 literal.Get<float>(from_multi_index));
441 break;
442 case F64:
443 new_literal.Set<double>(to_multi_index,
444 literal.Get<double>(from_multi_index));
445 break;
446 case C64:
447 new_literal.Set<complex64>(to_multi_index,
448 literal.Get<complex64>(from_multi_index));
449 break;
450 case C128:
451 new_literal.Set<complex128>(to_multi_index,
452 literal.Get<complex128>(from_multi_index));
453 break;
454 default:
455 LOG(FATAL) << "Unhandled primitive element type: "
456 << PrimitiveType_Name(literal.shape().element_type());
457 }
458 }
459
460 return new_literal;
461 }
462
GetFirstScalarLiteral(const LiteralSlice & literal)463 /* static */ Literal LiteralUtil::GetFirstScalarLiteral(
464 const LiteralSlice& literal) {
465 CHECK(literal.shape().IsArray());
466 CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
467 return CreateScalar<FirstElementProvider>(literal.shape().element_type(),
468 literal);
469 }
470
GetScalarLiteral(const LiteralBase & literal,absl::Span<const int64_t> multi_index)471 /*static*/ Literal LiteralUtil::GetScalarLiteral(
472 const LiteralBase& literal, absl::Span<const int64_t> multi_index) {
473 return CreateScalar<GetElementAtIndexProvider>(literal.shape().element_type(),
474 &literal, multi_index);
475 }
476
SetScalarLiteral(MutableLiteralBase & literal,absl::Span<const int64_t> multi_index,const LiteralBase & scalar)477 /*static*/ void LiteralUtil::SetScalarLiteral(
478 MutableLiteralBase& literal, absl::Span<const int64_t> multi_index,
479 const LiteralBase& scalar) {
480 switch (literal.shape().element_type()) {
481 case PRED:
482 SetScalarAtIndexImpl<PRED>(literal, multi_index, scalar);
483 break;
484 case U8:
485 SetScalarAtIndexImpl<U8>(literal, multi_index, scalar);
486 break;
487 case U16:
488 SetScalarAtIndexImpl<U16>(literal, multi_index, scalar);
489 break;
490 case U32:
491 SetScalarAtIndexImpl<U32>(literal, multi_index, scalar);
492 break;
493 case U64:
494 SetScalarAtIndexImpl<U64>(literal, multi_index, scalar);
495 break;
496 case S8:
497 SetScalarAtIndexImpl<S8>(literal, multi_index, scalar);
498 break;
499 case S16:
500 SetScalarAtIndexImpl<S16>(literal, multi_index, scalar);
501 break;
502 case S32:
503 SetScalarAtIndexImpl<S32>(literal, multi_index, scalar);
504 break;
505 case S64:
506 SetScalarAtIndexImpl<S64>(literal, multi_index, scalar);
507 break;
508 case F16:
509 SetScalarAtIndexImpl<F16>(literal, multi_index, scalar);
510 break;
511 case BF16:
512 SetScalarAtIndexImpl<BF16>(literal, multi_index, scalar);
513 break;
514 case F32:
515 SetScalarAtIndexImpl<F32>(literal, multi_index, scalar);
516 break;
517 case F64:
518 SetScalarAtIndexImpl<F64>(literal, multi_index, scalar);
519 break;
520 case C64:
521 SetScalarAtIndexImpl<C64>(literal, multi_index, scalar);
522 break;
523 case C128:
524 SetScalarAtIndexImpl<C128>(literal, multi_index, scalar);
525 break;
526 default:
527 LOG(FATAL) << "Unsupported element type: "
528 << literal.shape().element_type();
529 }
530 }
531
MaxElement(const LiteralSlice & literal)532 /* static */ Literal LiteralUtil::MaxElement(const LiteralSlice& literal) {
533 CHECK(literal.shape().IsArray());
534 CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
535 return CreateScalar<MaxElementProvider>(literal.shape().element_type(),
536 literal);
537 }
538
MakeTuple(absl::Span<const Literal * const> elements)539 /* static */ Literal LiteralUtil::MakeTuple(
540 absl::Span<const Literal* const> elements) {
541 std::vector<const Shape*> element_shapes;
542 element_shapes.reserve(elements.size());
543 for (const auto* element : elements) {
544 element_shapes.push_back(&element->shape());
545 }
546 Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes));
547 for (int i = 0, end = elements.size(); i < end; ++i) {
548 TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
549 }
550 return literal;
551 }
552
MakeTupleFromSlices(absl::Span<const LiteralSlice> elements)553 /* static */ Literal LiteralUtil::MakeTupleFromSlices(
554 absl::Span<const LiteralSlice> elements) {
555 std::vector<const Shape*> element_shapes;
556 element_shapes.reserve(elements.size());
557 for (const auto& element : elements) {
558 element_shapes.push_back(&element.shape());
559 }
560 Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes));
561 for (int i = 0, end = elements.size(); i < end; ++i) {
562 TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
563 }
564 return literal;
565 }
566
MakeTupleOwned(std::vector<Literal> elements)567 /* static */ Literal LiteralUtil::MakeTupleOwned(
568 std::vector<Literal> elements) {
569 std::vector<const Shape*> element_shapes;
570 element_shapes.reserve(elements.size());
571 for (const auto& element : elements) {
572 element_shapes.push_back(&element.shape());
573 }
574 Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes));
575 for (int64_t i = 0, end = elements.size(); i < end; ++i) {
576 TF_CHECK_OK(
577 literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
578 }
579 return literal;
580 }
581
MultiIndexAsString(absl::Span<const int64_t> multi_index)582 /* static */ std::string LiteralUtil::MultiIndexAsString(
583 absl::Span<const int64_t> multi_index) {
584 return StrCat("{", absl::StrJoin(multi_index, ","), "}");
585 }
586
587 } // namespace xla
588