• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24 
25 #include "absl/base/casts.h"
26 #include "absl/memory/memory.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_format.h"
29 #include "absl/strings/str_join.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/index_util.h"
32 #include "tensorflow/compiler/xla/primitive_util.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/hash/hash.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/mem.h"
42 #include "tensorflow/core/platform/types.h"
43 
44 namespace xla {
45 namespace {
46 
47 using absl::StrCat;
48 
49 constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
50 
51 // Converts between little and big endian.
52 //
53 // Precondition: size % 2 == 0 (elements in the array are 16 bits long)
ConvertEndianShort(string * bytes)54 void ConvertEndianShort(string* bytes) {
55   CHECK_EQ(bytes->size() / 2, 0);
56   for (int64 i = 0; i < bytes->size(); i += 2) {
57     std::swap((*bytes)[i], (*bytes)[i + 1]);
58   }
59 }
60 
ConvertEndianShort(char * bytes,int64 size)61 void ConvertEndianShort(char* bytes, int64 size) {
62   CHECK_EQ(size / 2, 0);
63   for (int64 i = 0; i < size; i += 2) {
64     std::swap(bytes[i], bytes[i + 1]);
65   }
66 }
67 
68 // Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
69 // able to transparently access the raw 16-bit value contained within.
// Identity pass-through: for types that already satisfy the absl::bit_cast
// contract, the "raw value" is simply the value itself.
template <typename T>
T GetRawValue(T value) {
  return value;
}
GetRawValue(Eigen::half val)74 uint16 GetRawValue(Eigen::half val) { return val.x; }
75 
LiteralProtoHasValues(const LiteralProto & proto)76 bool LiteralProtoHasValues(const LiteralProto& proto) {
77   return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
78          proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
79          proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
80          proto.c64s_size() || proto.c128s_size() ||
81          proto.tuple_literals_size() || !proto.f16s().empty() ||
82          !proto.bf16s().empty() || !proto.u16s().empty() ||
83          !proto.s16s().empty();
84 }
85 
86 }  // namespace
87 
~LiteralBase()88 LiteralBase::~LiteralBase() {}
89 
operator <<(std::ostream & out,const Literal & literal)90 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
91   out << literal.ToString();
92   return out;
93 }
94 
// Computes the strided-copy configuration between two array shapes: base
// index (all zeros), per-dimension step, and the dimension chosen for the
// tight inner copy loop, together with the stride the opposite side must
// advance per element of that loop.
MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    // Stepping the index enumeration by the whole minor loop lets each
    // callback invocation cover one contiguous run of elements.
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}
117 
// Constructs a literal of `shape` with array buffers allocated; delegates to
// the two-argument constructor.
Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}
120 
// Recursively builds the piece tree for `shape` rooted at `piece`: tuple
// shapes get one child piece per element; array shapes optionally get a
// 16-byte-aligned data buffer; any other shape must be zero-sized.
void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      // Child pieces point back into the shape owned by this literal.
      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    if (allocate_arrays) {
      // Literals can be used as DMA targets, which can require alignment. We
      // force a 16-byte minimum alignment.
      constexpr int kMinimumAlignment = 16;
      piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
          piece->size_bytes(), kMinimumAlignment)));
    }
  } else {
    // If the shape is neither an array nor tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}
147 
// Constructs a literal of `shape`, building the full piece tree. When
// `allocate_arrays` is false, leaf pieces have null buffers (used when
// buffers will be donated, e.g. by DecomposeTuple/MoveFrom).
Literal::Literal(const Shape& shape, bool allocate_arrays)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  // The root piece must alias the literal's own shape storage.
  CHECK(&root_piece_->subshape() == shape_.get());

  SetPiece(*shape_, root_piece_, allocate_arrays);
}
158 
// Frees all array buffers and the piece tree. root_piece_ may be null here
// (presumably for a moved-from or default-constructed literal), hence the
// guard.
Literal::~Literal() {
  if (root_piece_ != nullptr) {
    DeallocateBuffers();
    delete root_piece_;
  }
}
165 
// Releases every non-null array buffer in the piece tree. Buffers were
// obtained from AlignedMalloc (see SetPiece), so AlignedFree must be used.
void Literal::DeallocateBuffers() {
  root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->buffer());
        }
      });
}
174 
// Move constructor: default-initializes then delegates to move-assignment.
Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}
178 
// Move assignment via swap: `other` takes over our previous state and
// releases it when it is destroyed.
Literal& Literal::operator=(Literal&& other) {
  // Invariant: each literal's root piece aliases its own shape storage.
  DCHECK(&other.root_piece_->subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  DCHECK(&root_piece_->subshape() == shape_.get());

  return *this;
}
188 
// Returns a literal of `shape` with every array buffer zero-filled.
Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}
199 
// Typed implementation of CopySliceFrom: copies a `copy_size`-shaped block of
// elements starting at `src_base` in `src_literal` to the region starting at
// `dest_base` in this literal, using strided runs along the best minor
// dimension (see StrideConfig).
template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64> src_base,
    absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
  TF_RET_CHECK(src_literal.shape().rank() == src_base.size());
  TF_RET_CHECK(shape().rank() == dest_base.size());

  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If any of the two shapes are scalars, we can just call the StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform copy if neither src nor dest has dimensions with zero element,
    // otherwise it's a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64 src_index = linear_index(src_literal.shape(), src_indexes);
      int64 dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}
259 
CopyElementFrom(const LiteralSlice & src_literal,absl::Span<const int64> src_index,absl::Span<const int64> dest_index)260 Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
261                                            absl::Span<const int64> src_index,
262                                            absl::Span<const int64> dest_index) {
263   DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
264   const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
265       src_literal.shape(), src_index);
266   const int64 dest_linear_index =
267       IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
268   const int64 primitive_size =
269       ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
270 
271   char* dest_address =
272       static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
273   const char* source_address =
274       static_cast<const char*>(src_literal.untyped_data()) +
275       src_linear_index * primitive_size;
276   if (dest_address != source_address) {
277     memcpy(dest_address, source_address, primitive_size);
278   }
279   return Status::OK();
280 }
281 
// Deserializes a LiteralProto into a Literal. Fails on a missing shape or
// layout and on OPAQUE_TYPE sub-shapes. When `prohibit_empty_literal` is
// false, array pieces whose proto carries no values are left unpopulated.
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto, bool prohibit_empty_literal) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
    return InvalidArgument(
        "Literal shape cannot include OPAQUE_TYPE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        // Descend through the proto's tuple_literals to the element that
        // corresponds to this piece's shape index.
        const LiteralProto* proto_element = &proto;
        for (int64 i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return Status::OK();
        }
        if (piece->subshape().element_type() == TOKEN) {
          // Token pieces carry no data to copy.
          return Status::OK();
        }

        CHECK(piece->subshape().IsArray());

        // When prohibit_empty_literal is false (allowing literal with no
        // values), only copy from proto if the literal proto has values. This
        // mode is used for a learned cost model.
        if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
          TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
        }

        return Status::OK();
      }));

  // Explicit move for the conversion to StatusOr<Literal>.
  return std::move(literal);
}
336 
// Splits a tuple literal into a vector of element literals, transferring
// buffer ownership out of this literal (no data is copied). Afterwards this
// literal is reset to the nil shape.
std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    // Element literals are created without buffers; they adopt ours below.
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          // Prepend the tuple element index to address the source piece.
          ShapeIndex src_index = {i};
          for (int64 j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->set_buffer(src_piece.buffer());
          src_piece.set_buffer(nullptr);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
361 
362 namespace {
363 
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
// The shapes must be compatible; layouts may differ, which is why each
// element's linear index is recomputed for both sides.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  // Enumerate every multi-dimensional index and remap it per layout.
  std::vector<int64> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
380 
381 }  // namespace
382 
// Copies the array data of `src` into this piece. When the subshapes
// (including layout) are equal this is a single memcpy; otherwise elements
// are copied one by one through a per-type layout-aware copy.
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
    std::vector<int64> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
// Dispatches to the typed element-wise copy for each primitive type.
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                    \
  case (XLA_T):                                                           \
    CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                  subshape(), src.subshape());            \
    break;
      COPY_ELEMENTS(U8, uint8);
      COPY_ELEMENTS(U16, uint16);
      COPY_ELEMENTS(U32, uint32);
      COPY_ELEMENTS(U64, uint64);
      COPY_ELEMENTS(S8, int8);
      COPY_ELEMENTS(S16, int16);
      COPY_ELEMENTS(S32, int32);
      COPY_ELEMENTS(S64, int64);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  return Status::OK();
}
422 
// Copies the subtree of `src_literal` rooted at `src_shape_index` into the
// subtree of this literal rooted at `dest_shape_index`. The two subshapes
// must be compatible.
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
    return InvalidArgument(
        "Destination subshape incompatible with source subshape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_subshape));
  }
  return root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return Status::OK();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return Status::OK();
        }
        // Construct the index of the corresponding piece in the source literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
        return Status::OK();
      });
}
463 
// Moves the array buffers of `src_literal` into the subtree of this literal
// rooted at `dest_shape_index` (no data copy). The destination subshape must
// equal the source shape exactly. `src_literal` is left nil-shaped.
Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_->ForEachSubpiece(
      [&](const ShapeIndex& src_index, const Piece& src_piece) {
        if (!src_piece.subshape().IsArray()) {
          return;
        }

        // Destination piece index = dest_shape_index ++ src_index.
        ShapeIndex dest_index = dest_shape_index;
        for (int64 i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        // Free the destination's existing buffer before adopting the source's.
        tensorflow::port::AlignedFree(dest_piece.buffer());
        dest_piece.set_buffer(src_piece.buffer());
      });

  // Reset the source to the nil shape; its buffers now belong to us.
  src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
  delete src_literal.root_piece_;
  src_literal.root_piece_ = new LiteralBase::Piece();
  src_literal.root_piece_->set_subshape(src_literal.shape_.get());

  return Status::OK();
}
497 
CopySliceFrom(const LiteralSlice & src_literal,absl::Span<const int64> src_base,absl::Span<const int64> dest_base,absl::Span<const int64> copy_size)498 Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
499                                          absl::Span<const int64> src_base,
500                                          absl::Span<const int64> dest_base,
501                                          absl::Span<const int64> copy_size) {
502   TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
503   TF_RET_CHECK(src_literal.shape().IsArray())
504       << ShapeUtil::HumanString(src_literal.shape());
505   TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
506 
507   switch (shape().element_type()) {
508     case U8:
509       return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
510                                           copy_size);
511     case U16:
512       return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
513                                            copy_size);
514     case U32:
515       return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
516                                            copy_size);
517     case U64:
518       return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
519                                            copy_size);
520     case S8:
521       return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
522                                          copy_size);
523     case S16:
524       return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
525                                           copy_size);
526     case S32:
527       return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
528                                           copy_size);
529     case S64:
530       return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
531                                           copy_size);
532     case F16:
533       return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
534                                          copy_size);
535     case BF16:
536       return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
537                                              copy_size);
538     case F32:
539       return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
540                                           copy_size);
541     case F64:
542       return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
543                                            copy_size);
544     case C64:
545       return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
546                                               copy_size);
547     case C128:
548       return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
549                                                copy_size);
550     case PRED:
551       return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
552                                          copy_size);
553     default:
554       break;
555   }
556   return Unimplemented(
557       "Copying a slice from a Literal object with element type %d is not "
558       "implemented.",
559       shape().element_type());
560 }
561 
// Fills this rank-1 PRED literal from a bitmap; the literal's element count
// must equal the bitmap's bit count.
void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(element_count(), values.bits());
  CHECK_EQ(shape().element_type(), PRED);
  for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
    Set({i}, values.get(i));
  }
}
571 
// Returns a copy of this literal whose sub-shape at `shape_index` uses
// `new_layout`; the data is re-laid-out accordingly by CopyFrom.
Literal LiteralBase::Relayout(const Layout& new_layout,
                              const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}
583 
// Returns a copy of this literal with the layouts taken from
// `shape_with_layout`, which must be compatible with this literal's shape.
// Each array sub-shape is copied (and re-laid-out) independently.
Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}
601 
// Broadcasts this array literal into `result_shape`. `dimensions[i]` names
// the result dimension that source dimension i maps to; all other result
// dimensions are broadcast (every source element replicated along them).
StatusOr<Literal> LiteralBase::Broadcast(
    const Shape& result_shape, absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Broadcast only supports arrays.");
  }

  // Each mapped source dimension must match its result dimension exactly.
  for (int64 i = 0; i < dimensions.size(); i++) {
    TF_RET_CHECK(shape().dimensions(i) ==
                 result_shape.dimensions(dimensions[i]));
  }

  Literal result(result_shape);

  // scratch_source_index is temporary storage space for the computed index into
  // the input literal.  We put it here to avoid allocating an std::vector in
  // every iteration of ShapeUtil::ForEachIndex.
  std::vector<int64> scratch_source_index(shape().dimensions_size());

  char* dest_data = static_cast<char*>(result.untyped_data());
  const char* source_data = static_cast<const char*>(untyped_data());
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  ShapeUtil::ForEachIndex(
      result_shape, [&](absl::Span<const int64> output_index) {
        // Project the output index back to the source index via `dimensions`.
        for (int64 i = 0; i < dimensions.size(); ++i) {
          scratch_source_index[i] = output_index[dimensions[i]];
        }
        int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            result_shape, output_index);
        int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            shape(), scratch_source_index);
        memcpy(dest_data + primitive_size * dest_index,
               source_data + primitive_size * source_index, primitive_size);
        return true;
      });

  return std::move(result);
}
641 
// Returns this array literal reshaped to `dimensions`. The element count
// must be unchanged. Data is first brought into the default (dim0-major)
// layout so the flat value sequence can be reused as-is.
StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64 elements_before = ShapeUtil::ElementsIn(shape());
  int64 elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
669 
// Returns this array literal with its dimensions permuted per `permutation`
// (result dimension i comes from source dimension permutation[i]). Instead of
// moving elements, the layout is permuted to match, so the flat data can be
// copied verbatim with one memcpy.
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(IsPermutation(permutation, shape().rank()))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  Shape permuted_shape =
      ShapeUtil::PermuteDimensions(inverse_permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
  // most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  // The permuted shape must describe exactly the same number of bytes.
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}
706 
707 template <typename NativeT>
SliceInternal(const Shape & result_shape,absl::Span<const int64> start_indices) const708 Literal LiteralBase::SliceInternal(
709     const Shape& result_shape, absl::Span<const int64> start_indices) const {
710   Literal result_literal(result_shape);
711   DimensionVector new_indices(result_shape.rank());
712   CHECK(result_literal
713             .Populate<NativeT>([&](absl::Span<const int64> indices) {
714               for (int64 i = 0; i < result_shape.rank(); ++i) {
715                 new_indices[i] = indices[i] + start_indices[i];
716               }
717               return Get<NativeT>(new_indices);
718             })
719             .ok());
720   return result_literal;
721 }
722 
Slice(absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices) const723 Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
724                            absl::Span<const int64> limit_indices) const {
725   CHECK(shape().IsArray()) << "tuple is not supported for slice";
726 
727   DimensionVector result_dimensions;
728   for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
729     CHECK_GE(start_indices[dnum], 0);
730     CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
731         << "dnum = " << dnum;
732     int64 dimension = limit_indices[dnum] - start_indices[dnum];
733     CHECK_GE(dimension, 0) << "dnum = " << dnum;
734     result_dimensions.push_back(dimension);
735   }
736   const auto result_shape =
737       ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
738                                      LayoutUtil::MinorToMajor(shape()));
739   switch (result_shape.element_type()) {
740     case PRED:
741       return SliceInternal<bool>(result_shape, start_indices);
742     case U8:
743       return SliceInternal<uint8>(result_shape, start_indices);
744     case U16:
745       return SliceInternal<uint16>(result_shape, start_indices);
746     case U32:
747       return SliceInternal<uint32>(result_shape, start_indices);
748     case U64:
749       return SliceInternal<uint64>(result_shape, start_indices);
750     case S8:
751       return SliceInternal<int8>(result_shape, start_indices);
752     case S16:
753       return SliceInternal<int16>(result_shape, start_indices);
754     case S32:
755       return SliceInternal<int32>(result_shape, start_indices);
756     case S64:
757       return SliceInternal<int64>(result_shape, start_indices);
758     case F16:
759       return SliceInternal<half>(result_shape, start_indices);
760     case BF16:
761       return SliceInternal<bfloat16>(result_shape, start_indices);
762     case F32:
763       return SliceInternal<float>(result_shape, start_indices);
764     case F64:
765       return SliceInternal<double>(result_shape, start_indices);
766     case C64:
767       return SliceInternal<complex64>(result_shape, start_indices);
768     case C128:
769       return SliceInternal<complex128>(result_shape, start_indices);
770     default:
771       LOG(FATAL) << "not yet implemented: "
772                  << PrimitiveType_Name(result_shape.element_type());
773   }
774 }
775 
Clone() const776 Literal LiteralBase::Clone() const {
777   Literal result(shape());
778   TF_CHECK_OK(result.CopyFrom(*this));
779   return result;
780 }
781 
// Returns a human-readable rendering of the element at `multi_index` inside
// the dense array subliteral at `shape_index`. Floating-point values are
// formatted so they round-trip; complex values render as "(re, im)".
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return RoundTripFpToString(Get<half>(multi_index, shape_index));
    case F32:
      return RoundTripFpToString(Get<float>(multi_index, shape_index));
    case BF16:
      return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
    case F64:
      return RoundTripFpToString(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    default:
      // Non-array element types (e.g. TUPLE, TOKEN) cannot be rendered here.
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}
827 
GetIntegralAsS64(absl::Span<const int64> multi_index) const828 absl::optional<int64> LiteralBase::GetIntegralAsS64(
829     absl::Span<const int64> multi_index) const {
830   CHECK(LayoutUtil::IsDenseArray(shape()));
831   switch (shape().element_type()) {
832     case PRED:
833       return Get<bool>(multi_index);
834     case U8:
835       return Get<uint8>(multi_index);
836     case S32:
837       return Get<int32>(multi_index);
838     case S64:
839       return Get<int64>(multi_index);
840     case U32:
841       return Get<uint32>(multi_index);
842     case U64:
843       return Get<uint64>(multi_index);
844     default:
845       return absl::nullopt;
846   }
847 }
848 
GetAsDouble(absl::Span<const int64> multi_index) const849 absl::optional<double> LiteralBase::GetAsDouble(
850     absl::Span<const int64> multi_index) const {
851   CHECK(LayoutUtil::IsDenseArray(shape()));
852   switch (shape().element_type()) {
853     case F16:
854       return static_cast<double>(Get<half>(multi_index));
855     case F32:
856       return static_cast<double>(Get<float>(multi_index));
857     case F64:
858       return Get<double>(multi_index);
859     case BF16:
860       return static_cast<double>(Get<bfloat16>(multi_index));
861     default:
862       return absl::nullopt;
863   }
864 }
865 
GetAsComplex128(absl::Span<const int64> multi_index) const866 absl::optional<complex128> LiteralBase::GetAsComplex128(
867     absl::Span<const int64> multi_index) const {
868   switch (shape().element_type()) {
869     case BF16:
870       return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
871     case F16:
872       return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
873     case F32:
874       return {{Get<float>(multi_index), 0}};
875     case F64:
876       return {{Get<double>(multi_index), 0}};
877     case C64:
878       return {Get<complex64>(multi_index)};
879     case C128:
880       return {Get<complex128>(multi_index)};
881     case S8:
882       return {Get<int8>(multi_index)};
883     default:
884       return absl::nullopt;
885   }
886 }
887 
Hash() const888 size_t LiteralBase::Hash() const {
889   using tensorflow::Hash64;
890   using tensorflow::Hash64Combine;
891 
892   size_t hash_value = ShapeUtil::Hash(shape());
893 
894   ShapeUtil::ForEachSubshape(
895       shape(), [&](const Shape& subshape, const ShapeIndex& index) {
896         if (!subshape.IsArray()) {
897           return;
898         }
899 
900         CHECK(LayoutUtil::IsDense(subshape.layout()));
901         hash_value = Hash64Combine(
902             hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
903                                size_bytes(index)));
904       });
905 
906   return hash_value;
907 }
908 
SetIntegralAsS64(absl::Span<const int64> multi_index,int64 value)909 Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
910                                             int64 value) {
911   CHECK(LayoutUtil::IsDenseArray(shape()));
912   switch (shape().element_type()) {
913     case PRED:
914       Set<bool>(multi_index, value);
915       break;
916     case U8:
917       Set<uint8>(multi_index, value);
918       break;
919     case S32:
920       Set<int32>(multi_index, value);
921       break;
922     case S64:
923       Set<int64>(multi_index, value);
924       break;
925     case U32:
926       Set<uint32>(multi_index, value);
927       break;
928     case U64:
929       Set<uint64>(multi_index, value);
930       break;
931     default:
932       return FailedPrecondition("Array element type is not integral: %s",
933                                 PrimitiveType_Name(shape().element_type()));
934   }
935   return Status::OK();
936 }
937 
SetFromDouble(absl::Span<const int64> multi_index,double value)938 Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
939                                          double value) {
940   CHECK(LayoutUtil::IsDenseArray(shape()));
941   switch (shape().element_type()) {
942     case F16:
943       Set<half>(multi_index, Eigen::half(value));
944       break;
945     case F32:
946       Set<float>(multi_index, value);
947       break;
948     case F64:
949       Set<double>(multi_index, value);
950       break;
951     case BF16:
952       Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
953       break;
954     default:
955       return FailedPrecondition("Array element type is not floating: %s",
956                                 PrimitiveType_Name(shape().element_type()));
957   }
958   return Status::OK();
959 }
960 
961 namespace {
962 
ShapeToString(bool print_layout,const Shape & shape)963 string ShapeToString(bool print_layout, const Shape& shape) {
964   return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
965                       : ShapeUtil::HumanString(shape);
966 }
967 
968 void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
969                     bool print_shape, bool print_layout,
970                     std::vector<string>* pieces);
971 
TupleToStringHelper(const LiteralBase & literal,const ShapeIndex & shape_index,bool print_shape,bool print_layout,std::vector<string> * pieces)972 void TupleToStringHelper(const LiteralBase& literal,
973                          const ShapeIndex& shape_index, bool print_shape,
974                          bool print_layout, std::vector<string>* pieces) {
975   const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
976   pieces->push_back("(\n");
977   std::vector<string> tuple_pieces;
978   for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
979     ShapeIndex element_index = shape_index;
980     element_index.push_back(i);
981     std::vector<string> element_pieces;
982     ToStringHelper(literal, element_index, print_shape, print_layout,
983                    &element_pieces);
984     tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
985   }
986   pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
987   pieces->push_back("\n)");
988 }
989 
// Appends a textual rendering of the dense array at `shape_index` to
// `pieces` (e.g. "{ { 1, 2 }, { 3, 4 } }"), optionally preceded by the
// shape string when `print_shape` is set.
void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64 rank = subshape.rank();

  // Recursively renders one dimension per call. `dimensions` holds the
  // not-yet-visited dimension sizes; `accum_indices` holds the indices fixed
  // so far.
  std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
      to_string_recursive = [&](absl::Span<const int64> dimensions,
                                std::vector<int64>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        // Chooses spacing/newline decoration for a brace based on depth.
        auto brace_to_string = [&](string brace) -> string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat("  ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            // For rank > 3 arrays, annotate opening braces with the fixed
            // index (e.g. "{ /*i0=3*/") so readers can track position.
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices->size() < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Base case: all indices fixed; render one element.
          // Display predicates as 0s and 1s so that the string is more dense.
          string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          // Recursive case: iterate the current dimension, separating
          // siblings with "," and a space or newline depending on depth.
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            std::vector<int64> cloned_indices(*accum_indices);
            cloned_indices.push_back(i);
            to_string_recursive(dimensions.subspan(1), &cloned_indices);
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    pieces->push_back(" ");
  }
  std::vector<int64> indices = {};
  std::vector<int64> dimensions(subshape.dimensions().begin(),
                                subshape.dimensions().end());
  to_string_recursive(dimensions, &indices);
}
1062 
ToStringHelper(const LiteralBase & literal,const ShapeIndex & shape_index,bool print_shape,bool print_layout,std::vector<string> * pieces)1063 void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
1064                     bool print_shape, bool print_layout,
1065                     std::vector<string>* pieces) {
1066   const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
1067   CHECK(LayoutUtil::HasLayout(literal.shape()));
1068   CHECK(LayoutUtil::HasLayout(subshape));
1069   if (subshape.IsTuple()) {
1070     TupleToStringHelper(literal, shape_index, print_shape, print_layout,
1071                         pieces);
1072   } else if (subshape.IsToken()) {
1073     pieces->push_back("token");
1074   } else {
1075     CHECK(LayoutUtil::IsDenseArray(subshape));
1076     DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
1077                              pieces);
1078   }
1079 }
1080 
1081 }  // namespace
1082 
ToString() const1083 string LiteralBase::ToString() const {
1084   std::vector<string> pieces;
1085   CHECK(LayoutUtil::HasLayout(this->shape()));
1086   ToStringHelper(*this, {}, /*print_shape=*/true,
1087                  /*print_layout=*/false, &pieces);
1088   return absl::StrJoin(pieces, "");
1089 }
1090 
ToStringWithoutShape() const1091 string LiteralBase::ToStringWithoutShape() const {
1092   std::vector<string> pieces;
1093   CHECK(LayoutUtil::HasLayout(this->shape()));
1094   ToStringHelper(*this, {}, /*print_shape=*/false,
1095                  /*print_layout=*/false, &pieces);
1096   return absl::StrJoin(pieces, "");
1097 }
1098 
ToStringWithLayout() const1099 string LiteralBase::ToStringWithLayout() const {
1100   std::vector<string> pieces;
1101   CHECK(LayoutUtil::HasLayout(this->shape()));
1102   ToStringHelper(*this, {}, /*print_shape=*/true,
1103                  /*print_layout=*/true, &pieces);
1104   return absl::StrJoin(pieces, "");
1105 }
1106 
EachCellAsString(const std::function<void (absl::Span<const int64> indices,const string & value)> & per_cell) const1107 void LiteralBase::EachCellAsString(
1108     const std::function<void(absl::Span<const int64> indices,
1109                              const string& value)>& per_cell) const {
1110   if (ShapeUtil::IsZeroElementArray(shape())) {
1111     return;
1112   }
1113   std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
1114       shape(), /*linear_index=*/0);
1115   do {
1116     per_cell(indices, GetAsString(indices));
1117   } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
1118 }
1119 
1120 namespace {
1121 template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
ConvertBetweenNativeTypesWithConverter(const LiteralBase & src_literal,const ConverterType & converter)1122 Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
1123                                                const ConverterType& converter) {
1124   CHECK(src_literal.shape().IsArray());
1125   Literal result_literal(ShapeUtil::ChangeElementType(
1126       src_literal.shape(),
1127       primitive_util::NativeToPrimitiveType<NativeDestT>()));
1128   auto src_data = src_literal.data<NativeSrcT>();
1129   auto dest_data = result_literal.template data<NativeDestT>();
1130   int64 num_elements = src_literal.element_count();
1131 
1132   for (int64 i = 0; i < num_elements; ++i) {
1133     dest_data[i] = converter(src_data[i]);
1134   }
1135   return result_literal;
1136 }
1137 
// Overload selected for half -> complex conversions: converts through the
// complex type's underlying real type (float/double) rather than casting the
// half directly to the complex type.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1150 
// Overload selected for every other (src, dest) pair: a plain static_cast
// per element. Mutually exclusive with the half -> complex overload above.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
                            (!std::is_same<NativeDestT, complex64>::value &&
                             !std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1161 
// Bitcast between equal-sized types when the destination is not Eigen::half:
// reinterprets each element's raw bits via absl::bit_cast.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
                         !std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return absl::bit_cast<NativeDestT>(GetRawValue(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1173 
// Bitcast overload for a destination of Eigen::half (same-size source).
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short and then use raw_uint16_to_half.
  auto converter = [](NativeSrcT src) {
    return Eigen::half_impl::raw_uint16_to_half(
        absl::bit_cast<uint16>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}
1188 
// This template specialization is here to make the compiler happy. bit_cast
// has a static check that the types are the same size. This specialization
// should never be used because the source and destination types are checked
// for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Unreachable: BitcastConvert rejects mismatched bit widths before
  // dispatching here.
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
1199 
// Converts `src_literal` (whose element type must equal primitive_src_type)
// to primitive_dest_type, either by per-element value conversion or, when
// `bitcast` is set, by reinterpreting each element's bits.
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}
1217 
// Second dispatch level: given a compile-time source type, switches on the
// runtime destination type and instantiates the matching conversion.
// Bitcasting to complex types is rejected (falls through to Unimplemented).
template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      // Value conversion to complex is supported; bitcast is not.
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}
1259 
// First dispatch level: switches on the runtime source element type. A
// conversion to the literal's own type is a plain clone.
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
      // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}
1294 
1295 }  // namespace
1296 
Convert(PrimitiveType primitive_dest_type) const1297 StatusOr<Literal> LiteralBase::Convert(
1298     PrimitiveType primitive_dest_type) const {
1299   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
1300 }
1301 
BitcastConvert(PrimitiveType primitive_dest_type) const1302 StatusOr<Literal> LiteralBase::BitcastConvert(
1303     PrimitiveType primitive_dest_type) const {
1304   if (primitive_util::BitWidth(shape().element_type()) !=
1305       primitive_util::BitWidth(primitive_dest_type)) {
1306     return InvalidArgument(
1307         "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
1308         "%d",
1309         PrimitiveType_Name(shape().element_type()),
1310         PrimitiveType_Name(primitive_dest_type),
1311         primitive_util::BitWidth(shape().element_type()),
1312         primitive_util::BitWidth(primitive_dest_type));
1313   }
1314   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
1315 }
1316 
ConvertToShape(const Shape & dest_shape) const1317 StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
1318   if (!dest_shape.IsTuple()) {
1319     return Convert(dest_shape.element_type());
1320   }
1321   std::vector<Literal> elements;
1322   for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
1323     auto element = LiteralSlice(*this, {i});
1324     TF_ASSIGN_OR_RETURN(
1325         auto new_element,
1326         element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
1327     elements.push_back(std::move(new_element));
1328   }
1329   return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
1330 }
1331 
MoveIntoTuple(absl::Span<Literal> elements)1332 /* static */ Literal MutableLiteralBase::MoveIntoTuple(
1333     absl::Span<Literal> elements) {
1334   std::vector<Shape> element_shapes;
1335   for (const Literal& element : elements) {
1336     element_shapes.push_back(element.shape());
1337   }
1338   Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
1339                   /*allocate_arrays=*/false);
1340   for (int i = 0; i < elements.size(); ++i) {
1341     TF_CHECK_OK(
1342         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
1343   }
1344   return literal;
1345 }
1346 
// Recursively compares this piece against `other` element by element.
// `multi_index` carries the indices fixed so far; once it reaches full rank
// the two elements at that index are compared directly.
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  // Iterate over the next unfixed dimension, recursing one level deeper.
  for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}
1362 
EqualElements(const LiteralBase::Piece & other) const1363 bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
1364   DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
1365 
1366   if (ShapeUtil::Equal(subshape(), other.subshape()) &&
1367       LayoutUtil::IsDenseArray(subshape())) {
1368     CHECK_EQ(size_bytes(), other.size_bytes());
1369     return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
1370   }
1371 
1372   std::vector<int64> multi_index;
1373   switch (subshape().element_type()) {
1374     case PRED:
1375       return EqualElementsInternal<bool>(other, &multi_index);
1376     case U8:
1377       return EqualElementsInternal<uint8>(other, &multi_index);
1378     case S16:
1379       return EqualElementsInternal<int16>(other, &multi_index);
1380     case S32:
1381       return EqualElementsInternal<int32>(other, &multi_index);
1382     case S64:
1383       return EqualElementsInternal<int64>(other, &multi_index);
1384     case U16:
1385       return EqualElementsInternal<uint16>(other, &multi_index);
1386     case U32:
1387       return EqualElementsInternal<uint32>(other, &multi_index);
1388     case U64:
1389       return EqualElementsInternal<uint64>(other, &multi_index);
1390     case F32:
1391       return EqualElementsInternal<float>(other, &multi_index);
1392     case F64:
1393       return EqualElementsInternal<double>(other, &multi_index);
1394     case F16:
1395       return EqualElementsInternal<half>(other, &multi_index);
1396     case BF16:
1397       return EqualElementsInternal<bfloat16>(other, &multi_index);
1398     case C64:
1399       return EqualElementsInternal<complex64>(other, &multi_index);
1400     case C128:
1401       return EqualElementsInternal<complex128>(other, &multi_index);
1402     default:
1403       LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
1404                  << PrimitiveType_Name(subshape().element_type());
1405   }
1406 }
1407 
operator ==(const LiteralBase & other) const1408 bool LiteralBase::operator==(const LiteralBase& other) const {
1409   if (!ShapeUtil::Compatible(shape(), other.shape())) {
1410     return false;
1411   }
1412 
1413   return root_piece().ForEachSubpieceWithBool(
1414       [&](const ShapeIndex& index, const Piece& piece) {
1415         if (!piece.subshape().IsArray()) {
1416           return true;
1417         }
1418 
1419         const Piece& other_piece = other.piece(index);
1420         if (!piece.EqualElements(other_piece)) {
1421           return false;
1422         }
1423         return true;
1424       });
1425 }
1426 
1427 namespace {
1428 
1429 template <typename NativeT>
AllElementsEqualValue(absl::Span<const NativeT> data,NativeT value)1430 static bool AllElementsEqualValue(absl::Span<const NativeT> data,
1431                                   NativeT value) {
1432   for (int64 i = 0; i < data.size(); ++i) {
1433     if (data[i] != value) {
1434       return false;
1435     }
1436   }
1437   return true;
1438 }
1439 
1440 }  // namespace
1441 
IsAll(int8 value) const1442 bool LiteralBase::IsAll(int8 value) const {
1443   return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
1444                                                   const Piece& piece) {
1445     if (!piece.subshape().IsArray()) {
1446       return true;
1447     }
1448 
1449     auto piece_is_all = [&]() {
1450       switch (shape().element_type()) {
1451         case U8:
1452           if (value >= 0) {
1453             return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
1454           }
1455           return false;
1456         case U16:
1457           if (value >= 0) {
1458             return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
1459           }
1460           return false;
1461         case U32:
1462           if (value >= 0) {
1463             return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
1464           }
1465           return false;
1466         case U64:
1467           if (value >= 0) {
1468             return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
1469           }
1470           return false;
1471         case S8:
1472           return AllElementsEqualValue<int8>(piece.data<int8>(), value);
1473         case S16:
1474           return AllElementsEqualValue<int16>(piece.data<int16>(), value);
1475         case S32:
1476           return AllElementsEqualValue<int32>(piece.data<int32>(), value);
1477         case S64:
1478           return AllElementsEqualValue<int64>(piece.data<int64>(), value);
1479         case F32:
1480           return AllElementsEqualValue<float>(piece.data<float>(), value);
1481         case F64:
1482           return AllElementsEqualValue<double>(piece.data<double>(), value);
1483         case F16:
1484           return AllElementsEqualValue<half>(piece.data<half>(),
1485                                              static_cast<half>(value));
1486         case BF16:
1487           return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
1488                                                  static_cast<bfloat16>(value));
1489         case PRED:
1490           if (value == 0) {
1491             return AllElementsEqualValue<bool>(piece.data<bool>(), false);
1492           }
1493           if (value == 1) {
1494             return AllElementsEqualValue<bool>(piece.data<bool>(), true);
1495           }
1496           return false;
1497         default:
1498           return false;
1499       }
1500       return false;
1501     };
1502 
1503     if (!piece_is_all()) {
1504       return false;
1505     }
1506     return true;
1507   });
1508 }
1509 
IsAllFloat(float value) const1510 bool LiteralBase::IsAllFloat(float value) const {
1511   return root_piece().ForEachSubpieceWithBool(
1512       [&](const ShapeIndex& index, const Piece& piece) {
1513         if (!piece.subshape().IsArray()) {
1514           return true;
1515         }
1516 
1517         switch (shape().element_type()) {
1518           case F32:
1519             return AllElementsEqualValue<float>(piece.data<float>(), value);
1520           case F64:
1521             return AllElementsEqualValue<double>(piece.data<double>(), value);
1522           case F16:
1523             return AllElementsEqualValue<half>(piece.data<half>(),
1524                                                static_cast<half>(value));
1525           case BF16:
1526             return AllElementsEqualValue<bfloat16>(
1527                 piece.data<bfloat16>(), static_cast<bfloat16>(value));
1528           default:
1529             return false;
1530         }
1531       });
1532 }
1533 
IsAllComplex(complex64 value) const1534 bool LiteralBase::IsAllComplex(complex64 value) const {
1535   switch (shape().element_type()) {
1536     case C64:
1537       return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
1538                                               value);
1539     case C128:
1540       return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
1541                                                value);
1542     default:
1543       return false;
1544   }
1545 }
1546 
// Returns true iff, within every array piece, all elements equal that piece's
// own first element. Non-array pieces (tuples, tokens) are skipped;
// zero-element arrays make the result false since they have no first element.
bool LiteralBase::IsAllFirst() const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // Empty shapes are not all the first element since there is no first
        // element.
        if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
          return false;
        }
        // Dispatch on the piece's element type; every case compares the whole
        // buffer against its leading element data[0].
        auto piece_is_all = [&]() {
          switch (piece.subshape().element_type()) {
            case PRED: {
              auto data = piece.data<bool>();
              return AllElementsEqualValue<bool>(data, data[0]);
            }
            // 8 bit types
            case S8: {
              auto data = piece.data<int8>();
              return AllElementsEqualValue<int8>(data, data[0]);
            }
            case U8: {
              auto data = piece.data<uint8>();
              return AllElementsEqualValue<uint8>(data, data[0]);
            }
            // 16 bit types
            case BF16: {
              auto data = piece.data<bfloat16>();
              return AllElementsEqualValue<bfloat16>(data, data[0]);
            }
            case F16: {
              auto data = piece.data<half>();
              return AllElementsEqualValue<half>(data, data[0]);
            }
            case S16: {
              auto data = piece.data<int16>();
              return AllElementsEqualValue<int16>(data, data[0]);
            }
            case U16: {
              auto data = piece.data<uint16>();
              return AllElementsEqualValue<uint16>(data, data[0]);
            }
            // 32 bit types
            case F32: {
              auto data = piece.data<float>();
              return AllElementsEqualValue<float>(data, data[0]);
            }
            case U32: {
              auto data = piece.data<uint32>();
              return AllElementsEqualValue<uint32>(data, data[0]);
            }
            case S32: {
              auto data = piece.data<int32>();
              return AllElementsEqualValue<int32>(data, data[0]);
            }
            // 64 bit types
            case C64: {
              auto data = piece.data<complex64>();
              return AllElementsEqualValue<complex64>(data, data[0]);
            }
            case F64: {
              auto data = piece.data<double>();
              return AllElementsEqualValue<double>(data, data[0]);
            }
            case S64: {
              auto data = piece.data<int64>();
              return AllElementsEqualValue<int64>(data, data[0]);
            }
            case U64: {
              auto data = piece.data<uint64>();
              return AllElementsEqualValue<uint64>(data, data[0]);
            }

            case C128: {
              auto data = piece.data<complex128>();
              return AllElementsEqualValue<complex128>(data, data[0]);
            }
            default:
              // Element types without a case (e.g. OPAQUE) never qualify.
              return false;
          }
        };

        if (!piece_is_all()) {
          return false;
        }
        return true;
      });
}
1637 
// Returns true iff this literal is a rank-1 array whose element at index i
// equals i for every i (an "iota" sequence). Non-arrays, other ranks, and
// element types without a natural integer embedding (PRED, tuples, tokens,
// opaque) return false. Comparison is element-by-element via Get<>.
bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  // Checks a single position: the element at {idx} must equal idx after
  // conversion into the literal's element type.
  auto is_iota_at_idx = [&](const int64 idx) {
    switch (shape().element_type()) {
      case U8:
        return Get<uint8>({idx}) == idx;
      case U16:
        return Get<uint16>({idx}) == idx;
      case U32:
        return Get<uint32>({idx}) == idx;
      case U64:
        return Get<uint64>({idx}) == idx;
      case S8:
        return Get<int8>({idx}) == idx;
      case S16:
        return Get<int16>({idx}) == idx;
      case S32:
        return Get<int32>({idx}) == idx;
      case S64:
        return Get<int64>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        // Complex elements must have an exact real part and zero imaginary.
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // pred, token, opaque, tuple, etc. are all not iota.
      default:
        return false;
    }
  };

  const int64 elements = ShapeUtil::ElementsIn(shape());
  for (int64 idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  // Note: an empty (zero-element) R1 array passes vacuously.
  return true;
}
1692 
IsZero(absl::Span<const int64> indices) const1693 bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
1694   CHECK(shape().IsArray());
1695   switch (shape().element_type()) {
1696     case U8:
1697       return Get<uint8>(indices) == 0;
1698     case U16:
1699       return Get<uint16>(indices) == 0;
1700     case U32:
1701       return Get<uint32>(indices) == 0;
1702     case U64:
1703       return Get<uint64>(indices) == 0;
1704     case S8:
1705       return Get<int8>(indices) == 0;
1706     case S16:
1707       return Get<int16>(indices) == 0;
1708     case S32:
1709       return Get<int32>(indices) == 0;
1710     case S64:
1711       return Get<int64>(indices) == 0;
1712     case F32:
1713       return Get<float>(indices) == 0.0f;
1714     case F64:
1715       return Get<double>(indices) == 0.0;
1716     case C64:
1717       return Get<complex64>(indices) == complex64(0.0f, 0.0f);
1718     case C128:
1719       return Get<complex128>(indices) == complex128(0.0f, 0.0f);
1720     case F16:
1721       return Get<half>(indices) == static_cast<half>(0.0f);
1722     case BF16:
1723       return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
1724     case PRED:
1725       return Get<bool>(indices) == false;
1726     default:
1727       LOG(FATAL) << "Input literal must be an array.";
1728   }
1729 }
1730 
namespace {

// Overwrites `dest` (a protobuf RepeatedField-style container) with the
// contents of `src`, by constructing a fresh repeated field from the span's
// iterator range and assigning it from the temporary.
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}

}  // namespace
1740 
// Serializes this piece into `proto`: always the shape, plus element data for
// array pieces. Encoding by element type:
//  * PRED and the 32/64-bit numeric types go into typed repeated fields.
//  * S8/U8 are stored as raw byte strings.
//  * The 16-bit types (U16/S16/F16/BF16) are stored as byte strings,
//    normalized to little-endian (swapped on big-endian hosts).
//  * Complex values are flattened as interleaved (real, imag) pairs.
//  * TUPLE/TOKEN carry no data beyond the shape.
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case U16:
      // Raw byte copy of the 16-bit buffer; stored little-endian, so swap on
      // big-endian hosts.
      *proto->mutable_u16s() = string(
          reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      // Interleave real/imag per element: [re0, im0, re1, im1, ...].
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}
1823 
untyped_data() const1824 const void* LiteralBase::Piece::untyped_data() const {
1825   CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1826   return buffer();
1827 }
1828 
untyped_data()1829 void* LiteralBase::Piece::untyped_data() {
1830   CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1831   return buffer();
1832 }
1833 
1834 namespace {
1835 
1836 template <typename RepeatedFieldT, typename NativeT>
CopyFromRepeatedField(absl::Span<NativeT> dest,const RepeatedFieldT & src)1837 Status CopyFromRepeatedField(absl::Span<NativeT> dest,
1838                              const RepeatedFieldT& src) {
1839   if (dest.size() != src.size()) {
1840     return InvalidArgument(
1841         "Expected %lu elements in LiteralProto repeated field, has %d",
1842         dest.size(), src.size());
1843   }
1844   std::copy(src.begin(), src.end(), dest.begin());
1845   return Status::OK();
1846 }
1847 
1848 }  // namespace
1849 
// Deserializes element data from `proto` into this piece's buffer. The proto
// shape must exist, have a layout, and equal this piece's subshape (re-checked
// here even though MutableLiteralBase::CreateFromProto validates upstream).
// Byte-string encodings (S8/U8, and the 16-bit types) are size-checked and
// memcpy'd — with endian conversion on big-endian hosts — while complex
// fields are read back from interleaved (real, imag) pairs.
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      // S8/U8 are stored as byte strings; element-for-byte copy.
      auto s8_data = data<int8>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    case S16: {
      // 16-bit types arrive as little-endian byte strings; swap after the
      // copy on big-endian hosts.
      const string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;

    case BF16: {
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      // Complex data is interleaved [re0, im0, re1, im1, ...], hence the
      // factor-of-two size check and the stride-2 reads.
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Is called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return Status::OK();
}
1949 
ToProto() const1950 LiteralProto LiteralBase::ToProto() const {
1951   LiteralProto proto;
1952   root_piece().ForEachSubpiece(
1953       [&](const ShapeIndex& index, const Piece& piece) {
1954         LiteralProto* proto_piece = &proto;
1955         for (int64 i : index) {
1956           while (proto_piece->tuple_literals_size() <= i) {
1957             proto_piece->add_tuple_literals();
1958           }
1959           proto_piece = proto_piece->mutable_tuple_literals(i);
1960         }
1961         piece.WriteToProto(proto_piece);
1962       });
1963 
1964   return proto;
1965 }
1966 
untyped_data(const ShapeIndex & shape_index) const1967 const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
1968   return piece(shape_index).untyped_data();
1969 }
1970 
untyped_data(const ShapeIndex & shape_index)1971 void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
1972   return piece(shape_index).untyped_data();
1973 }
1974 
size_bytes(const ShapeIndex & shape_index) const1975 int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
1976   return piece(shape_index).size_bytes();
1977 }
1978 
GetR1U8AsString() const1979 string LiteralBase::GetR1U8AsString() const {
1980   CHECK(shape().IsArray());
1981   CHECK_EQ(shape().rank(), 1);
1982   CHECK_EQ(shape().element_type(), U8);
1983   return string(absl::bit_cast<const char*>(data<uint8>().data()),
1984                 ShapeUtil::ElementsIn(shape()));
1985 }
1986 
CopyPieceSubtree(const Shape & shape,Piece * src_piece,Piece * dest_piece)1987 void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
1988                                                Piece* src_piece,
1989                                                Piece* dest_piece) {
1990   DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
1991       << "src_piece has shape: "
1992       << ShapeUtil::HumanString(src_piece->subshape())
1993       << "dest_piece has shape: "
1994       << ShapeUtil::HumanString(dest_piece->subshape());
1995   if (shape.IsTuple()) {
1996     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1997       const Shape& subshape = shape.tuple_shapes(i);
1998 
1999       auto child_piece = Piece();
2000       child_piece.set_subshape(&subshape);
2001 
2002       CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
2003 
2004       dest_piece->emplace_back(std::move(child_piece));
2005     }
2006   } else if (shape.IsArray()) {
2007     dest_piece->set_buffer(src_piece->buffer());
2008   } else {
2009     // If the shape is neither an array nor tuple, then it must be
2010     // zero-sized. Otherwise, some memory needs to be allocated for it.
2011     CHECK_EQ(dest_piece->size_bytes(), 0);
2012   }
2013 }
2014 
~MutableLiteralBase()2015 MutableLiteralBase::~MutableLiteralBase() {}
2016 
// Copy constructor: builds a new borrowing view over the same array buffers
// as `literal`. The shape is deep-copied; the piece tree is rebuilt with
// buffers aliased via CopyPieceSubtree. root_piece_ is a raw owning pointer
// that the destructor deletes.
MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
2028 
// Copy assignment: rebinds this view to `literal`'s buffers (buffers are
// shared, not copied). Fixes a leak in the previous version, which
// overwrote the raw owning root_piece_ pointer with a fresh allocation
// without deleting the old piece tree; also guards against self-assignment.
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  if (this != &literal) {
    shape_ = absl::make_unique<Shape>(literal.shape());
    CHECK(LayoutUtil::HasLayout(*shape_));

    // Release any previously owned piece tree before building the new one.
    delete root_piece_;
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());

    CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
  }
  return *this;
}
2041 
// Builds a mutable borrowing view over `literal`'s buffers. The shape is
// deep-copied; array buffers are aliased via CopyPieceSubtree, so writes
// through this object mutate `literal`'s storage.
MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal->shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}
2052 
// Builds a mutable borrowing view over only the subtree of `literal` rooted
// at `view_root`; the sub-literal's buffers are aliased, not copied.
MutableBorrowingLiteral::MutableBorrowingLiteral(
    MutableBorrowingLiteral literal, const ShapeIndex& view_root)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}
2064 
// Wraps a flat, caller-provided buffer as a mutable non-tuple literal. The
// buffer is borrowed (const_cast'ed in, never copied); NOTE(review): the
// caller presumably must keep src_buf_ptr alive and correctly laid out for
// `shape` for this object's lifetime — confirm against call sites.
MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  CHECK(!shape_->IsTuple());

  root_piece_ = new Piece();
  root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_->set_subshape(shape_.get());
}
2076 
~MutableBorrowingLiteral()2077 MutableBorrowingLiteral::~MutableBorrowingLiteral() {
2078   if (root_piece_ != nullptr) {
2079     delete root_piece_;
2080   }
2081 }
2082 
// Non-owning, read-only view over the whole of `literal`; just aliases its
// root piece.
LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}
2085 
// Non-owning, read-only view over the subtree of `literal` addressed by
// `view_root`.
LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
2089 
// Recursively populates `piece` with child Piece nodes mirroring the tuple
// structure of `shape`. Only the structure (subshapes) is built here; leaf
// array buffers are attached later by the caller.
void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
  CHECK(shape.IsTuple());
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);

    auto child_piece = Piece();
    child_piece.set_subshape(&subshape);

    // Recurse only for nested tuples; array children stay leaf nodes.
    if (subshape.IsTuple()) {
      BuildPieceSubtree(subshape, &child_piece);
    }

    piece->emplace_back(std::move(child_piece));
  }
}
2105 
// Wraps a flat, caller-provided buffer as a read-only array literal. The
// buffer is borrowed, not copied; NOTE(review): presumably src_buf_ptr must
// outlive this object and match `shape`'s layout — confirm at call sites.
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsArray());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = Piece();
  root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_.set_subshape(shape_.get());
}
2115 
// Wraps one caller-provided buffer per tuple element as a read-only literal.
// Only flat (non-nested) tuples of arrays are supported, and exactly one
// buffer per element is required. Buffers are borrowed, not copied.
BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                                   const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsTuple());
  CHECK(!ShapeUtil::IsNestedTuple(*shape_));
  CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  // Build the structural piece tree first, then attach each element's buffer.
  BuildPieceSubtree(*shape_, &root_piece_);

  for (int i = 0; i < src_buf_ptrs.size(); ++i) {
    const auto& src_shape = shape_->tuple_shapes(i);
    CHECK(src_shape.IsArray());
    root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
  }
}
2132 
2133 }  // namespace xla
2134