• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal_util.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27 
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "tensorflow/compiler/xla/index_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 namespace xla {
39 namespace {
40 
41 using absl::StrCat;
42 
43 // Return a literal with all arrays of type FromNativeT converted to type
44 // ToNativeT in the given literal.
45 template <typename FromNativeT, typename ToNativeT>
ConvertType(LiteralSlice literal)46 Literal ConvertType(LiteralSlice literal) {
47   // First construct shape of the result.
48   Shape result_shape(literal.shape());
49   ShapeUtil::ForEachMutableSubshape(
50       &result_shape, [](Shape* subshape, const ShapeIndex&) {
51         if (subshape->element_type() ==
52             primitive_util::NativeToPrimitiveType<FromNativeT>()) {
53           subshape->set_element_type(
54               primitive_util::NativeToPrimitiveType<ToNativeT>());
55         }
56       });
57   Literal result(result_shape);
58 
59   // Then copy over the data from 'literal' converting FromNativeT values to
60   // ToNativeT values as necessary.
61   ShapeUtil::ForEachSubshape(
62       literal.shape(),
63       [&](const Shape& subshape, const ShapeIndex& shape_index) {
64         if (subshape.IsArray()) {
65           if (subshape.element_type() ==
66               primitive_util::NativeToPrimitiveType<FromNativeT>()) {
67             auto src = literal.data<FromNativeT>(shape_index);
68             auto dest = result.data<ToNativeT>(shape_index);
69             for (int64_t i = 0, end = src.size(); i < end; ++i) {
70               dest[i] = static_cast<ToNativeT>(src[i]);
71             }
72           } else {
73             TF_CHECK_OK(result.CopyFrom(literal,
74                                         /*dest_shape_index=*/shape_index,
75                                         /*src_shape_index=*/shape_index));
76           }
77         }
78       });
79   return result;
80 }
81 
82 template <PrimitiveType kType>
83 using NativeT = typename primitive_util::PrimitiveTypeToNative<kType>::type;
84 
85 template <PrimitiveType kType, typename F, typename... Args>
CreateScalarImpl(F && value_provider,Args...args)86 Literal CreateScalarImpl(F&& value_provider, Args... args) {
87   return LiteralUtil::CreateR0<NativeT<kType>>(
88       value_provider(std::forward<Args>(args)...));
89 }
90 
91 template <template <PrimitiveType> class F, typename... Args>
CreateScalar(PrimitiveType primitive_type,Args...args)92 Literal CreateScalar(PrimitiveType primitive_type, Args... args) {
93   switch (primitive_type) {
94     case U8:
95       return CreateScalarImpl<U8>(F<U8>{}, std::forward<Args>(args)...);
96     case U16:
97       return CreateScalarImpl<U16>(F<U16>{}, std::forward<Args>(args)...);
98     case U32:
99       return CreateScalarImpl<U32>(F<U32>{}, std::forward<Args>(args)...);
100     case U64:
101       return CreateScalarImpl<U64>(F<U64>{}, std::forward<Args>(args)...);
102     case S8:
103       return CreateScalarImpl<S8>(F<S8>{}, std::forward<Args>(args)...);
104     case S16:
105       return CreateScalarImpl<S16>(F<S16>{}, std::forward<Args>(args)...);
106     case S32:
107       return CreateScalarImpl<S32>(F<S32>{}, std::forward<Args>(args)...);
108     case S64:
109       return CreateScalarImpl<S64>(F<S64>{}, std::forward<Args>(args)...);
110     case F16:
111       return CreateScalarImpl<F16>(F<F16>{}, std::forward<Args>(args)...);
112     case BF16:
113       return CreateScalarImpl<BF16>(F<BF16>{}, std::forward<Args>(args)...);
114     case F32:
115       return CreateScalarImpl<F32>(F<F32>{}, std::forward<Args>(args)...);
116     case F64:
117       return CreateScalarImpl<F64>(F<F64>{}, std::forward<Args>(args)...);
118     case C64:
119       return CreateScalarImpl<C64>(F<C64>{}, std::forward<Args>(args)...);
120     case C128:
121       return CreateScalarImpl<C128>(F<C128>{}, std::forward<Args>(args)...);
122     case PRED:
123       return CreateScalarImpl<PRED>(F<PRED>{}, std::forward<Args>(args)...);
124     case TUPLE:
125       LOG(FATAL) << "tuple element type cannot be a scalar type.";
126     case OPAQUE_TYPE:
127       LOG(FATAL) << "opaque element type cannot be a scalar type.";
128     default:
129       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
130   }
131 }
132 
133 template <PrimitiveType kType>
134 struct ZeroProvider {
operator ()xla::__anond64a65160111::ZeroProvider135   NativeT<kType> operator()() const { return static_cast<NativeT<kType>>(0); }
136 };
137 
138 template <PrimitiveType kType>
139 struct OneProvider {
operator ()xla::__anond64a65160111::OneProvider140   NativeT<kType> operator()() const { return static_cast<NativeT<kType>>(1); }
141 };
142 
143 template <typename T>
144 struct Is16BitFloat {
145   static constexpr bool value =
146       std::is_same<bfloat16, T>::value || std::is_same<half, T>::value;
147 };
148 
149 template <typename T>
150 struct IsReal {
151   static constexpr bool value =
152       std::is_integral<T>::value || std::is_floating_point<T>::value ||
153       std::is_same<bfloat16, T>::value || std::is_same<half, T>::value;
154 };
155 
156 template <typename T>
157 struct IsValidScalarType {
158   static constexpr bool value = IsReal<T>::value ||
159                                 std::is_same<complex64, T>::value ||
160                                 std::is_same<complex128, T>::value;
161 };
162 
// Integral overload: the largest value representable in NativeT.
template <typename NativeT>
typename std::enable_if<std::is_integral<NativeT>::value, NativeT>::type
GetMaxImpl() {
  using Limits = std::numeric_limits<NativeT>;
  return Limits::max();
}
167 
// Integral overload: the smallest value representable in NativeT.
template <typename NativeT>
typename std::enable_if<std::is_integral<NativeT>::value, NativeT>::type
GetMinImpl() {
  using Limits = std::numeric_limits<NativeT>;
  return Limits::min();
}
172 
// Floating-point overload: the maximum is positive infinity, not the largest
// finite value.
template <typename NativeT>
typename std::enable_if<std::is_floating_point<NativeT>::value, NativeT>::type
GetMaxImpl() {
  const NativeT positive_infinity = std::numeric_limits<NativeT>::infinity();
  return positive_infinity;
}
177 
// Floating-point overload: the minimum is negative infinity, not the lowest
// finite value.
template <typename NativeT>
typename std::enable_if<std::is_floating_point<NativeT>::value, NativeT>::type
GetMinImpl() {
  const NativeT positive_infinity = std::numeric_limits<NativeT>::infinity();
  return -positive_infinity;
}
182 
183 template <typename NativeT>
GetMaxImpl()184 std::enable_if_t<Is16BitFloat<NativeT>::value, NativeT> GetMaxImpl() {
185   return static_cast<NativeT>(std::numeric_limits<float>::infinity());
186 }
187 
188 template <typename NativeT>
GetMinImpl()189 std::enable_if_t<Is16BitFloat<NativeT>::value, NativeT> GetMinImpl() {
190   return static_cast<NativeT>(-std::numeric_limits<float>::infinity());
191 }
192 
193 template <typename NativeT>
GetMaxImpl()194 std::enable_if_t<!IsReal<NativeT>::value, NativeT> GetMaxImpl() {
195   LOG(FATAL) << "No max value for given type.";
196 }
197 
198 template <typename NativeT>
GetMinImpl()199 std::enable_if_t<!IsReal<NativeT>::value, NativeT> GetMinImpl() {
200   LOG(FATAL) << "No min value for given type.";
201 }
202 
203 template <PrimitiveType kType>
204 struct MaxProvider {
operator ()xla::__anond64a65160111::MaxProvider205   NativeT<kType> operator()() const { return GetMaxImpl<NativeT<kType>>(); }
206 };
207 
208 template <PrimitiveType kType>
209 struct MinProvider {
operator ()xla::__anond64a65160111::MinProvider210   NativeT<kType> operator()() const { return GetMinImpl<NativeT<kType>>(); }
211 };
212 
213 template <PrimitiveType kType>
214 struct FirstElementProvider {
operator ()xla::__anond64a65160111::FirstElementProvider215   NativeT<kType> operator()(const LiteralBase& literal) const {
216     return literal.GetFirstElement<NativeT<kType>>();
217   }
218 };
219 
220 template <typename NativeT>
GetMaxElementImpl(const LiteralBase & literal)221 std::enable_if_t<IsReal<NativeT>::value, NativeT> GetMaxElementImpl(
222     const LiteralBase& literal) {
223   auto view = literal.data<NativeT>();
224   return *absl::c_max_element(view);
225 }
226 
227 template <typename NativeT>
GetMaxElementImpl(const LiteralBase & literal)228 std::enable_if_t<!IsReal<NativeT>::value, NativeT> GetMaxElementImpl(
229     const LiteralBase& literal) {
230   LOG(FATAL) << "Unsupported type.";
231 }
232 
233 template <PrimitiveType kType>
234 struct MaxElementProvider {
operator ()xla::__anond64a65160111::MaxElementProvider235   NativeT<kType> operator()(const LiteralBase& literal) const {
236     return GetMaxElementImpl<NativeT<kType>>(literal);
237   }
238 };
239 
240 template <typename NativeT>
241 std::enable_if_t<IsValidScalarType<NativeT>::value, NativeT>
GetElementAtIndexImpl(const LiteralBase * literal,absl::Span<const int64_t> multi_index)242 GetElementAtIndexImpl(const LiteralBase* literal,
243                       absl::Span<const int64_t> multi_index) {
244   return literal->Get<NativeT>(multi_index);
245 }
246 
247 template <typename NativeT>
248 std::enable_if_t<!IsValidScalarType<NativeT>::value, NativeT>
GetElementAtIndexImpl(const LiteralBase * literal,absl::Span<const int64_t> multi_index)249 GetElementAtIndexImpl(const LiteralBase* literal,
250                       absl::Span<const int64_t> multi_index) {
251   LOG(FATAL) << "Not a valid scalar element type.";
252 }
253 
254 template <PrimitiveType kType>
255 struct GetElementAtIndexProvider {
operator ()xla::__anond64a65160111::GetElementAtIndexProvider256   NativeT<kType> operator()(const LiteralBase* literal,
257                             absl::Span<const int64_t> multi_index) const {
258     DCHECK_EQ(literal->shape().element_type(), kType);
259     return GetElementAtIndexImpl<NativeT<kType>>(literal, multi_index);
260   }
261 };
262 
263 template <PrimitiveType kType>
SetScalarAtIndexImpl(MutableLiteralBase & literal,absl::Span<const int64_t> multi_index,const LiteralBase & scalar)264 void SetScalarAtIndexImpl(MutableLiteralBase& literal,
265                           absl::Span<const int64_t> multi_index,
266                           const LiteralBase& scalar) {
267   DCHECK_EQ(literal.shape().element_type(), kType);
268   using NativeT = typename primitive_util::PrimitiveTypeToNative<kType>::type;
269   literal.Set<NativeT>(multi_index, scalar.Get<NativeT>({}));
270 }
271 
272 }  // namespace
273 
CreateFromDimensions(PrimitiveType primitive_type,absl::Span<const int64_t> dimensions)274 /* static */ Literal LiteralUtil::CreateFromDimensions(
275     PrimitiveType primitive_type, absl::Span<const int64_t> dimensions) {
276   return Literal::CreateFromShape(
277       ShapeUtil::MakeShape(primitive_type, dimensions));
278 }
279 
ConvertBF16ToF32(const LiteralSlice & bf16_literal)280 /* static */ Literal LiteralUtil::ConvertBF16ToF32(
281     const LiteralSlice& bf16_literal) {
282   return ConvertType<bfloat16, float>(bf16_literal);
283 }
284 
ConvertBF16ToF64(const LiteralSlice & bf16_literal)285 /* static */ Literal LiteralUtil::ConvertBF16ToF64(
286     const LiteralSlice& bf16_literal) {
287   return ConvertType<bfloat16, double>(bf16_literal);
288 }
289 
ConvertF32ToBF16(const LiteralSlice & f32_literal)290 /* static */ Literal LiteralUtil::ConvertF32ToBF16(
291     const LiteralSlice& f32_literal) {
292   return ConvertType<float, bfloat16>(f32_literal);
293 }
294 
ConvertF32ToF64(const LiteralSlice & f32_literal)295 /* static */ Literal LiteralUtil::ConvertF32ToF64(
296     const LiteralSlice& f32_literal) {
297   return ConvertType<float, double>(f32_literal);
298 }
299 
ConvertF64ToBF16(const LiteralSlice & f64_literal)300 /* static */ Literal LiteralUtil::ConvertF64ToBF16(
301     const LiteralSlice& f64_literal) {
302   return ConvertType<double, bfloat16>(f64_literal);
303 }
304 
ConvertF64ToF32(const LiteralSlice & f64_literal)305 /* static */ Literal LiteralUtil::ConvertF64ToF32(
306     const LiteralSlice& f64_literal) {
307   return ConvertType<double, float>(f64_literal);
308 }
309 
ConvertS32ToF32(const LiteralSlice & s32_literal)310 /* static */ Literal LiteralUtil::ConvertS32ToF32(
311     const LiteralSlice& s32_literal) {
312   return ConvertType<int32_t, float>(s32_literal);
313 }
314 
CreateToken()315 /* static */ Literal LiteralUtil::CreateToken() {
316   return Literal(ShapeUtil::MakeTokenShape());
317 }
318 
Zero(PrimitiveType primitive_type)319 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
320   return CreateScalar<ZeroProvider>(primitive_type);
321 }
322 
One(PrimitiveType primitive_type)323 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
324   return CreateScalar<OneProvider>(primitive_type);
325 }
326 
MinValue(PrimitiveType primitive_type)327 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
328   return CreateScalar<MinProvider>(primitive_type);
329 }
330 
MaxValue(PrimitiveType primitive_type)331 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
332   return CreateScalar<MaxProvider>(primitive_type);
333 }
334 
NanValue(PrimitiveType primitive_type)335 /* static */ StatusOr<Literal> LiteralUtil::NanValue(
336     PrimitiveType primitive_type) {
337   switch (primitive_type) {
338     case F16:
339       return LiteralUtil::CreateR0<half>(
340           static_cast<half>(std::numeric_limits<float>::quiet_NaN()));
341     case BF16:
342       return LiteralUtil::CreateR0<bfloat16>(
343           static_cast<bfloat16>(std::numeric_limits<float>::quiet_NaN()));
344     case F32:
345       return LiteralUtil::CreateR0<float>(
346           std::numeric_limits<float>::quiet_NaN());
347     case F64:
348       return LiteralUtil::CreateR0<double>(
349           std::numeric_limits<double>::quiet_NaN());
350     case C64: {
351       float nan = std::numeric_limits<float>::quiet_NaN();
352       return LiteralUtil::CreateR0<complex64>(complex64(nan, nan));
353     }
354     case C128: {
355       double nan = std::numeric_limits<double>::quiet_NaN();
356       return LiteralUtil::CreateR0<complex128>(complex128(nan, nan));
357     }
358     default:
359       return InvalidArgument("Invalid type for NanValue: %s",
360                              PrimitiveType_Name(primitive_type));
361   }
362 }
363 
CreateR1(const tensorflow::core::Bitmap & values)364 /* static */ Literal LiteralUtil::CreateR1(
365     const tensorflow::core::Bitmap& values) {
366   Literal literal(
367       ShapeUtil::MakeShape(PRED, {static_cast<int64_t>(values.bits())}));
368   literal.PopulateR1(values);
369   return literal;
370 }
371 
CreateR1U8(absl::string_view value)372 /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
373   Literal literal(
374       ShapeUtil::MakeShape(U8, {static_cast<int64_t>(value.size())}));
375   for (int i = 0, end = value.size(); i < end; ++i) {
376     literal.Set<uint8_t>({i}, value[i]);
377   }
378   return literal;
379 }
380 
CreateR2F32Linspace(float from,float to,int64_t rows,int64_t cols)381 /* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
382                                                       int64_t rows,
383                                                       int64_t cols) {
384   auto value = MakeLinspaceArray2D(from, to, rows, cols);
385   return CreateR2FromArray2D(*value);
386 }
387 
ReshapeSlice(absl::Span<const int64_t> new_dimensions,absl::Span<const int64_t> minor_to_major,const LiteralSlice & literal)388 /* static */ Literal LiteralUtil::ReshapeSlice(
389     absl::Span<const int64_t> new_dimensions,
390     absl::Span<const int64_t> minor_to_major, const LiteralSlice& literal) {
391   int64_t new_num_elements = 1;
392   for (int64_t i = 0, end = new_dimensions.size(); i < end; ++i) {
393     new_num_elements *= new_dimensions[i];
394   }
395   CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
396   CHECK_EQ(new_dimensions.size(), minor_to_major.size());
397 
398   Literal new_literal(
399       ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
400 
401   // Create a new shape with the given minor-to-major layout. This shape is used
402   // solely for converting linear address to multi-dimensional addresses when
403   // writing elements to the new literal.
404   Shape shape_with_layout = new_literal.shape();
405   *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
406 
407   // Copy data into new literal, element-by-element.
408   for (int64_t i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
409     std::vector<int64_t> from_multi_index =
410         IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
411     std::vector<int64_t> to_multi_index =
412         IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
413     switch (literal.shape().element_type()) {
414       case PRED:
415         new_literal.Set<bool>(to_multi_index,
416                               literal.Get<bool>(from_multi_index));
417         break;
418       case U8:
419         new_literal.Set<uint8_t>(to_multi_index,
420                                  literal.Get<uint8_t>(from_multi_index));
421         break;
422       case U32:
423         new_literal.Set<uint32_t>(to_multi_index,
424                                   literal.Get<uint32_t>(from_multi_index));
425         break;
426       case S32:
427         new_literal.Set<int32_t>(to_multi_index,
428                                  literal.Get<int32_t>(from_multi_index));
429         break;
430       case U64:
431         new_literal.Set<uint64_t>(to_multi_index,
432                                   literal.Get<uint64_t>(from_multi_index));
433         break;
434       case S64:
435         new_literal.Set<int64_t>(to_multi_index,
436                                  literal.Get<int64_t>(from_multi_index));
437         break;
438       case F32:
439         new_literal.Set<float>(to_multi_index,
440                                literal.Get<float>(from_multi_index));
441         break;
442       case F64:
443         new_literal.Set<double>(to_multi_index,
444                                 literal.Get<double>(from_multi_index));
445         break;
446       case C64:
447         new_literal.Set<complex64>(to_multi_index,
448                                    literal.Get<complex64>(from_multi_index));
449         break;
450       case C128:
451         new_literal.Set<complex128>(to_multi_index,
452                                     literal.Get<complex128>(from_multi_index));
453         break;
454       default:
455         LOG(FATAL) << "Unhandled primitive element type: "
456                    << PrimitiveType_Name(literal.shape().element_type());
457     }
458   }
459 
460   return new_literal;
461 }
462 
GetFirstScalarLiteral(const LiteralSlice & literal)463 /* static */ Literal LiteralUtil::GetFirstScalarLiteral(
464     const LiteralSlice& literal) {
465   CHECK(literal.shape().IsArray());
466   CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
467   return CreateScalar<FirstElementProvider>(literal.shape().element_type(),
468                                             literal);
469 }
470 
GetScalarLiteral(const LiteralBase & literal,absl::Span<const int64_t> multi_index)471 /*static*/ Literal LiteralUtil::GetScalarLiteral(
472     const LiteralBase& literal, absl::Span<const int64_t> multi_index) {
473   return CreateScalar<GetElementAtIndexProvider>(literal.shape().element_type(),
474                                                  &literal, multi_index);
475 }
476 
SetScalarLiteral(MutableLiteralBase & literal,absl::Span<const int64_t> multi_index,const LiteralBase & scalar)477 /*static*/ void LiteralUtil::SetScalarLiteral(
478     MutableLiteralBase& literal, absl::Span<const int64_t> multi_index,
479     const LiteralBase& scalar) {
480   switch (literal.shape().element_type()) {
481     case PRED:
482       SetScalarAtIndexImpl<PRED>(literal, multi_index, scalar);
483       break;
484     case U8:
485       SetScalarAtIndexImpl<U8>(literal, multi_index, scalar);
486       break;
487     case U16:
488       SetScalarAtIndexImpl<U16>(literal, multi_index, scalar);
489       break;
490     case U32:
491       SetScalarAtIndexImpl<U32>(literal, multi_index, scalar);
492       break;
493     case U64:
494       SetScalarAtIndexImpl<U64>(literal, multi_index, scalar);
495       break;
496     case S8:
497       SetScalarAtIndexImpl<S8>(literal, multi_index, scalar);
498       break;
499     case S16:
500       SetScalarAtIndexImpl<S16>(literal, multi_index, scalar);
501       break;
502     case S32:
503       SetScalarAtIndexImpl<S32>(literal, multi_index, scalar);
504       break;
505     case S64:
506       SetScalarAtIndexImpl<S64>(literal, multi_index, scalar);
507       break;
508     case F16:
509       SetScalarAtIndexImpl<F16>(literal, multi_index, scalar);
510       break;
511     case BF16:
512       SetScalarAtIndexImpl<BF16>(literal, multi_index, scalar);
513       break;
514     case F32:
515       SetScalarAtIndexImpl<F32>(literal, multi_index, scalar);
516       break;
517     case F64:
518       SetScalarAtIndexImpl<F64>(literal, multi_index, scalar);
519       break;
520     case C64:
521       SetScalarAtIndexImpl<C64>(literal, multi_index, scalar);
522       break;
523     case C128:
524       SetScalarAtIndexImpl<C128>(literal, multi_index, scalar);
525       break;
526     default:
527       LOG(FATAL) << "Unsupported element type: "
528                  << literal.shape().element_type();
529   }
530 }
531 
MaxElement(const LiteralSlice & literal)532 /* static */ Literal LiteralUtil::MaxElement(const LiteralSlice& literal) {
533   CHECK(literal.shape().IsArray());
534   CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
535   return CreateScalar<MaxElementProvider>(literal.shape().element_type(),
536                                           literal);
537 }
538 
MakeTuple(absl::Span<const Literal * const> elements)539 /* static */ Literal LiteralUtil::MakeTuple(
540     absl::Span<const Literal* const> elements) {
541   std::vector<const Shape*> element_shapes;
542   element_shapes.reserve(elements.size());
543   for (const auto* element : elements) {
544     element_shapes.push_back(&element->shape());
545   }
546   Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes));
547   for (int i = 0, end = elements.size(); i < end; ++i) {
548     TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
549   }
550   return literal;
551 }
552 
MakeTupleFromSlices(absl::Span<const LiteralSlice> elements)553 /* static */ Literal LiteralUtil::MakeTupleFromSlices(
554     absl::Span<const LiteralSlice> elements) {
555   std::vector<const Shape*> element_shapes;
556   element_shapes.reserve(elements.size());
557   for (const auto& element : elements) {
558     element_shapes.push_back(&element.shape());
559   }
560   Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes));
561   for (int i = 0, end = elements.size(); i < end; ++i) {
562     TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
563   }
564   return literal;
565 }
566 
MakeTupleOwned(std::vector<Literal> elements)567 /* static */ Literal LiteralUtil::MakeTupleOwned(
568     std::vector<Literal> elements) {
569   std::vector<const Shape*> element_shapes;
570   element_shapes.reserve(elements.size());
571   for (const auto& element : elements) {
572     element_shapes.push_back(&element.shape());
573   }
574   Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes));
575   for (int64_t i = 0, end = elements.size(); i < end; ++i) {
576     TF_CHECK_OK(
577         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
578   }
579   return literal;
580 }
581 
MultiIndexAsString(absl::Span<const int64_t> multi_index)582 /* static */ std::string LiteralUtil::MultiIndexAsString(
583     absl::Span<const int64_t> multi_index) {
584   return StrCat("{", absl::StrJoin(multi_index, ","), "}");
585 }
586 
587 }  // namespace xla
588