• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal_util.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/index_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/hash/hash.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mem.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace xla {
40 namespace {
41 
42 using absl::StrCat;
43 
44 // Return a literal with all arrays of type FromNativeT converted to type
45 // ToNativeT in the given literal.
46 template <typename FromNativeT, typename ToNativeT>
ConvertType(LiteralSlice literal)47 Literal ConvertType(LiteralSlice literal) {
48   // First construct shape of the result.
49   Shape result_shape(literal.shape());
50   ShapeUtil::ForEachMutableSubshape(
51       &result_shape, [](Shape* subshape, const ShapeIndex&) {
52         if (subshape->element_type() ==
53             primitive_util::NativeToPrimitiveType<FromNativeT>()) {
54           subshape->set_element_type(
55               primitive_util::NativeToPrimitiveType<ToNativeT>());
56         }
57       });
58   Literal result(result_shape);
59 
60   // Then copy over the data from 'literal' converting FromNativeT values to
61   // ToNativeT values as necessary.
62   ShapeUtil::ForEachSubshape(
63       literal.shape(),
64       [&](const Shape& subshape, const ShapeIndex& shape_index) {
65         if (subshape.IsArray()) {
66           if (subshape.element_type() ==
67               primitive_util::NativeToPrimitiveType<FromNativeT>()) {
68             auto src = literal.data<FromNativeT>(shape_index);
69             auto dest = result.data<ToNativeT>(shape_index);
70             for (int64_t i = 0, end = src.size(); i < end; ++i) {
71               dest[i] = static_cast<ToNativeT>(src[i]);
72             }
73           } else {
74             TF_CHECK_OK(result.CopyFrom(literal,
75                                         /*dest_shape_index=*/shape_index,
76                                         /*src_shape_index=*/shape_index));
77           }
78         }
79       });
80   return result;
81 }
82 
83 }  // namespace
84 
CreateFromDimensions(PrimitiveType primitive_type,absl::Span<const int64> dimensions)85 /* static */ Literal LiteralUtil::CreateFromDimensions(
86     PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
87   return Literal::CreateFromShape(
88       ShapeUtil::MakeShape(primitive_type, dimensions));
89 }
90 
ConvertBF16ToF32(const LiteralSlice & bf16_literal)91 /* static */ Literal LiteralUtil::ConvertBF16ToF32(
92     const LiteralSlice& bf16_literal) {
93   return ConvertType<bfloat16, float>(bf16_literal);
94 }
95 
ConvertBF16ToF64(const LiteralSlice & bf16_literal)96 /* static */ Literal LiteralUtil::ConvertBF16ToF64(
97     const LiteralSlice& bf16_literal) {
98   return ConvertType<bfloat16, double>(bf16_literal);
99 }
100 
ConvertF32ToBF16(const LiteralSlice & f32_literal)101 /* static */ Literal LiteralUtil::ConvertF32ToBF16(
102     const LiteralSlice& f32_literal) {
103   return ConvertType<float, bfloat16>(f32_literal);
104 }
105 
ConvertF32ToF64(const LiteralSlice & f32_literal)106 /* static */ Literal LiteralUtil::ConvertF32ToF64(
107     const LiteralSlice& f32_literal) {
108   return ConvertType<float, double>(f32_literal);
109 }
110 
ConvertF64ToBF16(const LiteralSlice & f64_literal)111 /* static */ Literal LiteralUtil::ConvertF64ToBF16(
112     const LiteralSlice& f64_literal) {
113   return ConvertType<double, bfloat16>(f64_literal);
114 }
115 
ConvertF64ToF32(const LiteralSlice & f64_literal)116 /* static */ Literal LiteralUtil::ConvertF64ToF32(
117     const LiteralSlice& f64_literal) {
118   return ConvertType<double, float>(f64_literal);
119 }
120 
CreateToken()121 /* static */ Literal LiteralUtil::CreateToken() {
122   return Literal(ShapeUtil::MakeTokenShape());
123 }
124 
Zero(PrimitiveType primitive_type)125 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
126   switch (primitive_type) {
127     case U8:
128       return LiteralUtil::CreateR0<uint8>(0);
129     case U16:
130       return LiteralUtil::CreateR0<uint16>(0);
131     case U32:
132       return LiteralUtil::CreateR0<uint32>(0);
133     case U64:
134       return LiteralUtil::CreateR0<uint64>(0);
135     case S8:
136       return LiteralUtil::CreateR0<int8>(0);
137     case S16:
138       return LiteralUtil::CreateR0<int16>(0);
139     case S32:
140       return LiteralUtil::CreateR0<int32>(0);
141     case S64:
142       return LiteralUtil::CreateR0<int64>(0);
143     case F16:
144       return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
145     case BF16:
146       return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
147     case F32:
148       return LiteralUtil::CreateR0<float>(0);
149     case F64:
150       return LiteralUtil::CreateR0<double>(0);
151     case C64:
152       return LiteralUtil::CreateR0<complex64>(0);
153     case C128:
154       return LiteralUtil::CreateR0<complex128>(0);
155     case PRED:
156       return LiteralUtil::CreateR0<bool>(false);
157     case TUPLE:
158       LOG(FATAL) << "tuple element type cannot take on value of 0";
159     case OPAQUE_TYPE:
160       LOG(FATAL) << "opaque element type cannot take on value of 0";
161     default:
162       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
163   }
164 }
165 
One(PrimitiveType primitive_type)166 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
167   switch (primitive_type) {
168     case U8:
169       return LiteralUtil::CreateR0<uint8>(1);
170     case U16:
171       return LiteralUtil::CreateR0<uint16>(1);
172     case U32:
173       return LiteralUtil::CreateR0<uint32>(1);
174     case U64:
175       return LiteralUtil::CreateR0<uint64>(1);
176     case S8:
177       return LiteralUtil::CreateR0<int8>(1);
178     case S16:
179       return LiteralUtil::CreateR0<int16>(1);
180     case S32:
181       return LiteralUtil::CreateR0<int32>(1);
182     case S64:
183       return LiteralUtil::CreateR0<int64>(1);
184     case F16:
185       return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
186     case BF16:
187       return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
188     case F32:
189       return LiteralUtil::CreateR0<float>(1);
190     case F64:
191       return LiteralUtil::CreateR0<double>(1);
192     case C64:
193       return LiteralUtil::CreateR0<complex64>(1);
194     case C128:
195       return LiteralUtil::CreateR0<complex128>(1);
196     case PRED:
197       return LiteralUtil::CreateR0<bool>(true);
198     case TUPLE:
199       LOG(FATAL) << "tuple element type cannot take on value of 1";
200     case OPAQUE_TYPE:
201       LOG(FATAL) << "opaque element type cannot take on value of 1";
202     default:
203       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
204   }
205 }
206 
MinValue(PrimitiveType primitive_type)207 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
208   switch (primitive_type) {
209     case U8:
210       return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
211     case U16:
212       return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::min());
213     case U32:
214       return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
215     case U64:
216       return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
217     case S8:
218       return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
219     case S16:
220       return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::min());
221     case S32:
222       return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
223     case S64:
224       return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
225     case F32:
226       return LiteralUtil::CreateR0<float>(
227           -std::numeric_limits<float>::infinity());
228     case F64:
229       return LiteralUtil::CreateR0<double>(
230           -std::numeric_limits<double>::infinity());
231     case C64:
232       LOG(FATAL) << "C64 element type has no minimum value";
233     case C128:
234       LOG(FATAL) << "C128 element type has no minimum value";
235     case PRED:
236       return LiteralUtil::CreateR0<bool>(false);
237     case F16:
238       return LiteralUtil::CreateR0<half>(
239           static_cast<half>(-std::numeric_limits<float>::infinity()));
240     case BF16:
241       return LiteralUtil::CreateR0<bfloat16>(
242           static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
243     case TUPLE:
244       LOG(FATAL) << "tuple element type has no minimum value";
245     case OPAQUE_TYPE:
246       LOG(FATAL) << "opaque element type has no minimum value";
247     default:
248       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
249   }
250 }
251 
MaxValue(PrimitiveType primitive_type)252 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
253   switch (primitive_type) {
254     case U8:
255       return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
256     case U16:
257       return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::max());
258     case U32:
259       return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
260     case U64:
261       return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
262     case S8:
263       return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
264     case S16:
265       return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::max());
266     case S32:
267       return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
268     case S64:
269       return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
270     case F32:
271       return LiteralUtil::CreateR0<float>(
272           std::numeric_limits<float>::infinity());
273     case F64:
274       return LiteralUtil::CreateR0<double>(
275           std::numeric_limits<double>::infinity());
276     case PRED:
277       return LiteralUtil::CreateR0<bool>(true);
278     case F16:
279       return LiteralUtil::CreateR0<half>(
280           static_cast<half>(std::numeric_limits<float>::infinity()));
281     case BF16:
282       return LiteralUtil::CreateR0<bfloat16>(
283           static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
284     case TUPLE:
285       LOG(FATAL) << "tuple element type has no maximum value";
286     case OPAQUE_TYPE:
287       LOG(FATAL) << "opaque element type has no maximum value";
288     default:
289       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
290   }
291 }
292 
NanValue(PrimitiveType primitive_type)293 /* static */ StatusOr<Literal> LiteralUtil::NanValue(
294     PrimitiveType primitive_type) {
295   switch (primitive_type) {
296     case F16:
297       return LiteralUtil::CreateR0<half>(
298           static_cast<half>(std::numeric_limits<float>::quiet_NaN()));
299     case BF16:
300       return LiteralUtil::CreateR0<bfloat16>(
301           static_cast<bfloat16>(std::numeric_limits<float>::quiet_NaN()));
302     case F32:
303       return LiteralUtil::CreateR0<float>(
304           std::numeric_limits<float>::quiet_NaN());
305     case F64:
306       return LiteralUtil::CreateR0<double>(
307           std::numeric_limits<double>::quiet_NaN());
308     case C64: {
309       float nan = std::numeric_limits<float>::quiet_NaN();
310       return LiteralUtil::CreateR0<complex64>(complex64(nan, nan));
311     }
312     case C128: {
313       double nan = std::numeric_limits<double>::quiet_NaN();
314       return LiteralUtil::CreateR0<complex128>(complex128(nan, nan));
315     }
316     default:
317       return InvalidArgument("Invalid type for NanValue: %s",
318                              PrimitiveType_Name(primitive_type));
319   }
320 }
321 
CreateR1(const tensorflow::core::Bitmap & values)322 /* static */ Literal LiteralUtil::CreateR1(
323     const tensorflow::core::Bitmap& values) {
324   Literal literal(
325       ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
326   literal.PopulateR1(values);
327   return literal;
328 }
329 
CreateR1U8(absl::string_view value)330 /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
331   Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
332   for (int i = 0, end = value.size(); i < end; ++i) {
333     literal.Set<uint8>({i}, value[i]);
334   }
335   return literal;
336 }
337 
CreateR2F32Linspace(float from,float to,int64_t rows,int64_t cols)338 /* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
339                                                       int64_t rows,
340                                                       int64_t cols) {
341   auto value = MakeLinspaceArray2D(from, to, rows, cols);
342   return CreateR2FromArray2D(*value);
343 }
344 
ReshapeSlice(absl::Span<const int64> new_dimensions,absl::Span<const int64> minor_to_major,const LiteralSlice & literal)345 /* static */ Literal LiteralUtil::ReshapeSlice(
346     absl::Span<const int64> new_dimensions,
347     absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
348   int64_t new_num_elements = 1;
349   for (int64_t i = 0, end = new_dimensions.size(); i < end; ++i) {
350     new_num_elements *= new_dimensions[i];
351   }
352   CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
353   CHECK_EQ(new_dimensions.size(), minor_to_major.size());
354 
355   Literal new_literal(
356       ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
357 
358   // Create a new shape with the given minor-to-major layout. This shape is used
359   // solely for converting linear address to multi-dimensional addresses when
360   // writing elements to the new literal.
361   Shape shape_with_layout = new_literal.shape();
362   *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
363 
364   // Copy data into new literal, element-by-element.
365   for (int64_t i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
366     std::vector<int64> from_multi_index =
367         IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
368     std::vector<int64> to_multi_index =
369         IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
370     switch (literal.shape().element_type()) {
371       case PRED:
372         new_literal.Set<bool>(to_multi_index,
373                               literal.Get<bool>(from_multi_index));
374         break;
375       case U8:
376         new_literal.Set<uint8>(to_multi_index,
377                                literal.Get<uint8>(from_multi_index));
378         break;
379       case U32:
380         new_literal.Set<uint32>(to_multi_index,
381                                 literal.Get<uint32>(from_multi_index));
382         break;
383       case S32:
384         new_literal.Set<int32>(to_multi_index,
385                                literal.Get<int32>(from_multi_index));
386         break;
387       case U64:
388         new_literal.Set<uint64>(to_multi_index,
389                                 literal.Get<uint64>(from_multi_index));
390         break;
391       case S64:
392         new_literal.Set<int64>(to_multi_index,
393                                literal.Get<int64>(from_multi_index));
394         break;
395       case F32:
396         new_literal.Set<float>(to_multi_index,
397                                literal.Get<float>(from_multi_index));
398         break;
399       case F64:
400         new_literal.Set<double>(to_multi_index,
401                                 literal.Get<double>(from_multi_index));
402         break;
403       case C64:
404         new_literal.Set<complex64>(to_multi_index,
405                                    literal.Get<complex64>(from_multi_index));
406         break;
407       case C128:
408         new_literal.Set<complex128>(to_multi_index,
409                                     literal.Get<complex128>(from_multi_index));
410         break;
411       default:
412         LOG(FATAL) << "Unhandled primitive element type: "
413                    << PrimitiveType_Name(literal.shape().element_type());
414     }
415   }
416 
417   return new_literal;
418 }
419 
GetFirstScalarLiteral(const LiteralSlice & literal)420 /* static */ Literal LiteralUtil::GetFirstScalarLiteral(
421     const LiteralSlice& literal) {
422   CHECK(literal.shape().IsArray());
423   CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
424   switch (literal.shape().element_type()) {
425     case PRED:
426       return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
427     // 8 bit types.
428     case S8:
429       return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
430     case U8:
431       return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
432     // 16 bit types.
433     case BF16:
434       return LiteralUtil::CreateR0<bfloat16>(
435           literal.GetFirstElement<bfloat16>());
436     case F16:
437       return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
438     case S16:
439       return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
440     case U16:
441       return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
442     // 32 bit types.
443     case F32:
444       return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
445     case S32:
446       return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
447     case U32:
448       return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
449     // 64 bit types.
450     case C64:
451       return LiteralUtil::CreateR0<complex64>(
452           literal.GetFirstElement<complex64>());
453     case F64:
454       return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
455     case S64:
456       return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
457     case U64:
458       return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
459 
460     case C128:
461       return LiteralUtil::CreateR0<complex128>(
462           literal.GetFirstElement<complex128>());
463     default:
464       LOG(FATAL) << "Unhandled primitive type "
465                  << literal.shape().element_type();
466   }
467 }
468 
MaxElement(const LiteralSlice & literal)469 /* static */ Literal LiteralUtil::MaxElement(const LiteralSlice& literal) {
470   CHECK(literal.shape().IsArray());
471   CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
472   switch (literal.shape().element_type()) {
473     case PRED: {
474       auto view = literal.data<bool>();
475       return LiteralUtil::CreateR0<bool>(*absl::c_max_element(view));
476     }
477     // 8 bit types.
478     case S8: {
479       auto view = literal.data<int8>();
480       return LiteralUtil::CreateR0<int8>(*absl::c_max_element(view));
481     }
482     case U8: {
483       auto view = literal.data<uint8>();
484       return LiteralUtil::CreateR0<uint8>(*absl::c_max_element(view));
485     }
486     // 16 bit types.
487     case BF16: {
488       auto view = literal.data<bfloat16>();
489       return LiteralUtil::CreateR0<bfloat16>(*absl::c_max_element(view));
490     }
491     case F16: {
492       auto view = literal.data<half>();
493       return LiteralUtil::CreateR0<half>(*absl::c_max_element(view));
494     }
495     case S16: {
496       auto view = literal.data<int16>();
497       return LiteralUtil::CreateR0<int16>(*absl::c_max_element(view));
498     }
499     case U16: {
500       auto view = literal.data<uint16>();
501       return LiteralUtil::CreateR0<uint16>(*absl::c_max_element(view));
502     }
503     // 32 bit types.
504     case F32: {
505       auto view = literal.data<float>();
506       return LiteralUtil::CreateR0<float>(*absl::c_max_element(view));
507     }
508     case S32: {
509       auto view = literal.data<int32>();
510       return LiteralUtil::CreateR0<int32>(*absl::c_max_element(view));
511     }
512     case U32: {
513       auto view = literal.data<uint32>();
514       return LiteralUtil::CreateR0<uint32>(*absl::c_max_element(view));
515     }
516     case F64: {
517       auto view = literal.data<double>();
518       return LiteralUtil::CreateR0<double>(*absl::c_max_element(view));
519     }
520     case S64: {
521       auto view = literal.data<int64>();
522       return LiteralUtil::CreateR0<int64>(*absl::c_max_element(view));
523     }
524     case U64: {
525       auto view = literal.data<uint64>();
526       return LiteralUtil::CreateR0<uint64>(*absl::c_max_element(view));
527     }
528     default:
529       LOG(FATAL) << "Unhandled primitive type "
530                  << literal.shape().element_type();
531   }
532 }
533 
MakeTuple(absl::Span<const Literal * const> elements)534 /* static */ Literal LiteralUtil::MakeTuple(
535     absl::Span<const Literal* const> elements) {
536   std::vector<Shape> element_shapes;
537   for (const auto* element : elements) {
538     element_shapes.push_back(element->shape());
539   }
540   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
541   for (int i = 0, end = elements.size(); i < end; ++i) {
542     TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
543   }
544   return literal;
545 }
546 
MakeTupleFromSlices(absl::Span<const LiteralSlice> elements)547 /* static */ Literal LiteralUtil::MakeTupleFromSlices(
548     absl::Span<const LiteralSlice> elements) {
549   std::vector<Shape> element_shapes;
550   for (const auto& element : elements) {
551     element_shapes.push_back(element.shape());
552   }
553   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
554   for (int i = 0, end = elements.size(); i < end; ++i) {
555     TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
556   }
557   return literal;
558 }
559 
MakeTupleOwned(std::vector<Literal> elements)560 /* static */ Literal LiteralUtil::MakeTupleOwned(
561     std::vector<Literal> elements) {
562   std::vector<Shape> element_shapes;
563   element_shapes.reserve(elements.size());
564   for (const auto& element : elements) {
565     element_shapes.push_back(element.shape());
566   }
567   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
568   for (int64_t i = 0, end = elements.size(); i < end; ++i) {
569     TF_CHECK_OK(
570         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
571   }
572   return literal;
573 }
574 
MultiIndexAsString(absl::Span<const int64> multi_index)575 /* static */ string LiteralUtil::MultiIndexAsString(
576     absl::Span<const int64> multi_index) {
577   return StrCat("{", absl::StrJoin(multi_index, ","), "}");
578 }
579 
580 }  // namespace xla
581