• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <type_traits>
24 #include <vector>
25 
26 #include "absl/base/casts.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_format.h"
30 #include "absl/strings/str_join.h"
31 #include "absl/strings/str_split.h"
32 #include "absl/types/optional.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/index_util.h"
35 #include "tensorflow/compiler/xla/permutation_util.h"
36 #include "tensorflow/compiler/xla/primitive_util.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/hash/hash.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/mem.h"
46 #include "tensorflow/core/platform/types.h"
47 
48 namespace xla {
49 namespace {
50 
51 using absl::StrCat;
52 
53 constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
54 // Literals can be used as DMA targets, which can require alignment. We
55 // force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
56 // alignment.
57 constexpr int kMinimumAlignment = 64;
58 
59 // Converts between little and big endian.
60 //
61 // Precondition: size % 2 == 0 (elements in the array are 16 bits long)
ConvertEndianShort(string * bytes)62 void ConvertEndianShort(string* bytes) {
63   CHECK_EQ(bytes->size() % 2, 0);
64   for (int64_t i = 0, end = bytes->size(); i < end; i += 2) {
65     std::swap((*bytes)[i], (*bytes)[i + 1]);
66   }
67 }
68 
ConvertEndianShort(char * bytes,int64_t size)69 void ConvertEndianShort(char* bytes, int64_t size) {
70   CHECK_EQ(size % 2, 0);
71   for (int64_t i = 0; i < size; i += 2) {
72     std::swap(bytes[i], bytes[i + 1]);
73   }
74 }
75 
// Collapses all whitespace (spaces and newlines) in `input` into single
// spaces and drops leading/trailing whitespace, producing a one-line string.
string CompactOneline(const string& input) {
  // StrSplit with SkipEmpty discards the empty tokens produced by runs of
  // consecutive separators (and by leading/trailing separators), so joining
  // the remaining tokens with a single space yields the compacted form.
  return absl::StrJoin(
      absl::StrSplit(input, absl::ByAnyChar("\n "), absl::SkipEmpty()), " ");
}
91 
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  // Identity for every type that can already be bit_cast directly.
  return val;
}
// Overload for Eigen::half: exposes the underlying 16-bit representation.
uint16 GetRawValue(Eigen::half val) {
  return Eigen::numext::bit_cast<uint16>(val);
}
101 
// Returns true if `proto` carries any element data in any of its typed value
// fields or nested tuple literals, i.e. it is not a shape-only proto.
bool LiteralProtoHasValues(const LiteralProto& proto) {
  return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
         proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
         proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
         proto.c64s_size() || proto.c128s_size() ||
         proto.tuple_literals_size() || !proto.f16s().empty() ||
         !proto.bf16s().empty() || !proto.u16s().empty() ||
         !proto.s16s().empty();
}
111 
112 }  // namespace
113 
~LiteralBase()114 LiteralBase::~LiteralBase() {}
115 
operator <<(std::ostream & out,const Literal & literal)116 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
117   out << literal.ToString();
118   return out;
119 }
120 
// Sets up the iteration configuration used when copying between two literals:
// chooses a "minor" dimension along which a tight strided copy can run, and
// initializes the base indices and per-dimension steps accordingly.
MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      // Iterate contiguously over the source; stride through the destination.
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      // Iterate contiguously over the destination; stride through the source.
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    minor_loop_size = dimensions[minor_dimension];
    // Advance the minor dimension by a whole copy block per outer iteration.
    step[minor_dimension] = minor_loop_size;
  }
}
143 
// Constructs a literal of `shape` with all array buffers allocated (contents
// uninitialized).
Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}
146 
SetPiece(const Shape & shape,Piece * piece,bool allocate_arrays)147 void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
148   if (shape.IsTuple()) {
149     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
150       const Shape& subshape = shape.tuple_shapes(i);
151 
152       auto child_piece = Piece();
153       child_piece.set_subshape(&subshape);
154 
155       SetPiece(subshape, &child_piece, allocate_arrays);
156 
157       piece->emplace_back(std::move(child_piece));
158     }
159   } else if (shape.IsArray()) {
160     if (allocate_arrays) {
161       piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
162           piece->size_bytes(), kMinimumAlignment)));
163       if (shape.is_dynamic()) {
164         CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
165         piece->set_dynamic_size_buffer(
166             static_cast<int32*>(tensorflow::port::AlignedMalloc(
167                 piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
168       }
169     }
170   } else {
171     // If the shape is neither an array nor tuple, then it must be
172     // zero-sized. Otherwise, some memory needs to be allocated for it.
173     CHECK_EQ(piece->size_bytes(), 0);
174   }
175 }
176 
// Builds a literal for `shape`: takes a private copy of the shape and
// constructs the piece tree; array leaves get backing buffers only when
// `allocate_arrays` is true (DecomposeTuple/MoveFrom transfer buffers later).
Literal::Literal(const Shape& shape, bool allocate_arrays)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  // The piece tree must point at the shape instance owned by this literal.
  root_piece_->set_subshape(shape_.get());
  CHECK(&root_piece_->subshape() == shape_.get());

  SetPiece(*shape_, root_piece_, allocate_arrays);
}
187 
~Literal()188 Literal::~Literal() {
189   if (root_piece_ != nullptr) {
190     DeallocateBuffers();
191     delete root_piece_;
192   }
193 }
194 
DeallocateBuffers()195 void Literal::DeallocateBuffers() {
196   root_piece_->ForEachMutableSubpiece(
197       [&](const ShapeIndex& index, Piece* piece) {
198         if (piece->buffer() != nullptr) {
199           tensorflow::port::AlignedFree(piece->buffer());
200         }
201         if (piece->dynamic_size_buffer() != nullptr) {
202           tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
203         }
204       });
205 }
206 
// Move constructor: delegates to the move-assignment operator below.
Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}
210 
// Move assignment implemented as a swap: `other` takes over this literal's
// previous state and releases it when destroyed.
Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_->subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  // Invariant: the piece tree points at the shape owned by the same object.
  DCHECK(&root_piece_->subshape() == shape_.get());

  return *this;
}
220 
CreateFromShape(const Shape & shape)221 Literal LiteralBase::CreateFromShape(const Shape& shape) {
222   Literal literal(shape);
223   literal.root_piece_->ForEachMutableSubpiece(
224       [&](const ShapeIndex& index, Piece* piece) {
225         if (piece->subshape().IsArray()) {
226           memset(piece->untyped_data(), 0, piece->size_bytes());
227         }
228       });
229   return literal;
230 }
231 
// Returns the dynamic size of dimension `dim_index` at the literal's root.
int32 LiteralBase::GetDynamicSize(int64_t dim_index) const {
  return GetDynamicSize(dim_index, {});
}
235 
// Returns the dynamic size of dimension `dim_index` for the sub-literal at
// `shape_index`.
int32 LiteralBase::GetDynamicSize(int64_t dim_index,
                                  const ShapeIndex& shape_index) const {
  return piece(shape_index).GetDynamicSize(dim_index);
}
240 
// Returns the first element as an int64 when the literal has an integral
// element type; nullopt otherwise, or when a uint64 value does not fit.
absl::optional<int64> LiteralBase::GetFirstInteger() const {
  switch (shape().element_type()) {
    case U8:
      return GetFirstElement<uint8>();
    case U16:
      return GetFirstElement<uint16>();
    case U32:
      return GetFirstElement<uint32>();
    case U64: {
      int64_t v = GetFirstElement<uint64>();
      // A uint64 above INT64_MAX becomes negative after the implicit
      // conversion; report it as not representable.
      if (v < 0) {
        return absl::nullopt;
      }
      return v;
    }
    case S8:
      return GetFirstElement<int8>();
    case S16:
      return GetFirstElement<int16>();
    case S32:
      return GetFirstElement<int32>();
    case S64:
      return GetFirstElement<int64>();
    default:
      // Non-integral element types have no integer representation.
      return absl::nullopt;
  }
}
268 
// Copies a multi-dimensional region of `copy_size` elements per dimension
// from `src_literal` (anchored at `src_base`) into this literal (anchored at
// `dest_base`). NativeT must match both literals' element type.
template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64> src_base,
    absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
  const int64_t src_base_size = src_base.size();
  const int64_t dest_base_size = dest_base.size();
  TF_RET_CHECK(src_literal.shape().rank() == src_base_size);
  TF_RET_CHECK(shape().rank() == dest_base_size);

  // Maps a multi-dimensional index to a linear buffer offset for `shape`.
  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If any of the two shapes are scalars, we can just call the StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform copy if neither src nor dest has dimensions with zero element,
    // otherwise it's a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64_t src_index = linear_index(src_literal.shape(), src_indexes);
      int64_t dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}
330 
// Copies the single element at `src_index` in `src_literal` into this
// literal at `dest_index`. Element types must match.
Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
                                           absl::Span<const int64> src_index,
                                           absl::Span<const int64> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64_t src_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(src_literal.shape(),
                                                    src_index);
  const int64_t dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64_t primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  // Skip the memcpy for self-copies: memcpy on identical (overlapping)
  // addresses would be undefined behavior.
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return Status::OK();
}
353 
// Deserializes `proto` into a Literal. The proto must carry a shape with a
// layout; OPAQUE_TYPE sub-shapes are rejected. When `prohibit_empty_literal`
// is false, array pieces whose proto has no values are left unpopulated.
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto, bool prohibit_empty_literal) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
    return InvalidArgument(
        "Literal shape cannot include OPAQUE_TYPE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        // Walk `index` down through the proto's nested tuple_literals to the
        // sub-proto matching this piece.
        const LiteralProto* proto_element = &proto;
        for (int64_t i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return Status::OK();
        }
        if (piece->subshape().element_type() == TOKEN) {
          // Tokens carry no data.
          return Status::OK();
        }

        CHECK(piece->subshape().IsArray());

        // When prohibit_empty_literal is false (allowing literal with no
        // values), only copy from proto if the literal proto has values. This
        // mode is used for a learned cost model.
        if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
          TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
        }

        return Status::OK();
      }));

  return std::move(literal);
}
408 
SubLiteral(ShapeIndexView shape_index)409 Literal Literal::SubLiteral(ShapeIndexView shape_index) {
410   if (!shape_index.empty()) {
411     auto decomposed = this->DecomposeTuple();
412     return decomposed.at(shape_index.front())
413         .SubLiteral(shape_index.ConsumeFront());
414   } else {
415     return std::move(*this);
416   }
417 }
418 
// Destructively splits this tuple literal into a vector of its element
// literals, transferring buffer ownership. Afterwards this literal is reset
// to the nil shape.
std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    // allocate_arrays=false: buffers are moved over from this literal below.
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          // Prepend the tuple element number to address the source piece.
          ShapeIndex src_index = {i};
          for (int64_t j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->set_buffer(src_piece.buffer());
          dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
          src_piece.set_buffer(nullptr);
          src_piece.set_dynamic_size_buffer(nullptr);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
445 
446 namespace {
447 
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  // Visit every multi-dimensional index once, translating it to the linear
  // offset appropriate for each shape's layout.
  std::vector<int64> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
464 }  // namespace
465 
// Returns the runtime size of dimension `dim_index`: the recorded dynamic
// size for dynamic dimensions, otherwise the static bound from the shape.
int32 LiteralBase::Piece::GetDynamicSize(int64_t dim_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  if (!subshape_->is_dynamic_dimension(dim_index)) {
    // This is a static dimension, return size.
    return subshape_->dimensions(dim_index);
  }
  CHECK_NE(dynamic_size_buffer(), nullptr);
  return dynamic_size_buffer_[dim_index];
}
475 
SetDynamicSize(int64_t dim_index,int32_t size)476 void LiteralBase::Piece::SetDynamicSize(int64_t dim_index, int32_t size) {
477   CHECK(LayoutUtil::IsDenseArray(subshape()));
478   CHECK(subshape_->is_dynamic_dimension(dim_index));
479   if (dynamic_size_buffer() == nullptr) {
480     // Lazily initialize the dynamic size buffer.
481     set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
482         dynamic_size_buffer_bytes(), kMinimumAlignment)));
483     /*for (int64 i = 0; i < subshape().rank(); ++i) {
484       // Initialized to -1 to help debug.
485       dynamic_size_buffer_[i] = -1;
486     }*/
487   }
488   dynamic_size_buffer_[dim_index] = size;
489 }
490 
// Copies element data from `src` into this piece. Identical shapes (including
// layout) are memcpy'd; otherwise elements are copied one at a time through
// the appropriate native type. With `only_dynamic_bound`, shapes may differ
// within dynamic bounds and CopyElementsWithDynamicBound is used instead.
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
                                    bool only_dynamic_bound) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    std::vector<int64> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
// Generates one switch case per element type, dispatching to the
// NATIVE_T-typed element-wise copy routine.
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                      \
  case (XLA_T):                                                             \
    if (only_dynamic_bound) {                                               \
      CopyElementsWithDynamicBound<NATIVE_T>(src);                          \
    } else {                                                                \
      CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                    subshape(), src.subshape());            \
    }                                                                       \
    break;
      COPY_ELEMENTS(U8, uint8);
      COPY_ELEMENTS(U16, uint16);
      COPY_ELEMENTS(U32, uint32);
      COPY_ELEMENTS(U64, uint64);
      COPY_ELEMENTS(S8, int8);
      COPY_ELEMENTS(S16, int16);
      COPY_ELEMENTS(S32, int32);
      COPY_ELEMENTS(S64, int64);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
  // Also carry over the dynamic dimension sizes when both sides are dynamic.
  if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
    CHECK_NE(dynamic_size_buffer_, nullptr);
    CHECK_NE(src.dynamic_size_buffer_, nullptr);
    memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
           src.dynamic_size_buffer_bytes());
  }
  return Status::OK();
}
541 
SetDynamicSize(int64_t dim_index,int32_t size)542 void MutableLiteralBase::SetDynamicSize(int64_t dim_index, int32_t size) {
543   return SetDynamicSize(dim_index, {}, size);
544 }
545 
SetDynamicSize(int64_t dim_index,const ShapeIndex & shape_index,int32_t size)546 void MutableLiteralBase::SetDynamicSize(int64_t dim_index,
547                                         const ShapeIndex& shape_index,
548                                         int32_t size) {
549   Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
550   CHECK_GE(subshape_->dimensions(dim_index), size);
551   if (subshape_->dimensions(dim_index) == size) {
552     subshape_->set_dynamic_dimension(dim_index, false);
553     return;
554   }
555   subshape_->set_dynamic_dimension(dim_index, true);
556   piece(shape_index).SetDynamicSize(dim_index, size);
557 }
558 
// Copies the sub-literal of `src_literal` at `src_shape_index` into this
// literal's subtree at `dest_shape_index`. Shapes must be compatible, unless
// `only_dynamic_bound` allows one side to be the compact form of the other.
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index,
                                    bool only_dynamic_bound) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (only_dynamic_bound) {
    // One side is the static (compact) form; the other carries the dynamic
    // bound. They must be dynamically compatible.
    auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
    auto compact_shape =
        dest_subshape.is_static() ? dest_subshape : src_subshape;
    CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
        << compact_shape.ToString() << " vs " << bound_shape.ToString();
  } else {
    if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
      return InvalidArgument(
          "Destination subshape incompatible with source subshape: %s vs %s",
          ShapeUtil::HumanString(dest_subshape),
          ShapeUtil::HumanString(src_subshape));
    }
  }
  return root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return Status::OK();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return Status::OK();
        }
        // Construct the index of the corresponding piece in the source literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64_t i = dest_shape_index.size(), end = index.size(); i < end;
             ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(
            piece->CopyFrom(src_literal.piece(src_piece_index),
                            /*only_dynamic_bound=*/only_dynamic_bound));
        return Status::OK();
      });
}
611 
// Moves the buffers of `src_literal` into this literal's subtree at
// `dest_shape_index` (no element copying). Shapes must be equal, including
// layout. `src_literal` is left in a valid nil-shaped state.
Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_->ForEachSubpiece(
      [&](const ShapeIndex& src_index, const Piece& src_piece) {
        if (!src_piece.subshape().IsArray()) {
          return;
        }

        // Translate the source index into the destination subtree.
        ShapeIndex dest_index = dest_shape_index;
        for (int64_t i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        // Free the destination's old buffers before adopting the source's.
        tensorflow::port::AlignedFree(dest_piece.buffer());
        tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
        dest_piece.set_buffer(src_piece.buffer());
        dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
      });

  // Reset the source to an empty (nil-shaped) literal so its destructor does
  // not free the buffers we just adopted.
  src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
  delete src_literal.root_piece_;
  src_literal.root_piece_ = new LiteralBase::Piece();
  src_literal.root_piece_->set_subshape(src_literal.shape_.get());

  return Status::OK();
}
647 
// Copies a `copy_size` region from `src_literal` at `src_base` into this
// literal at `dest_base`, dispatching on element type to the typed
// CopySliceFromInternal. Both literals must be arrays of the same type.
Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64> src_base,
                                         absl::Span<const int64> dest_base,
                                         absl::Span<const int64> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
                                          copy_size);
    case U16:
      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
                                           copy_size);
    case U32:
      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
                                           copy_size);
    case U64:
      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
                                           copy_size);
    case S8:
      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
                                         copy_size);
    case S16:
      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
                                          copy_size);
    case S32:
      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
                                          copy_size);
    case S64:
      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
                                          copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
                                               copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}
711 
PopulateR1(const tensorflow::core::Bitmap & values)712 void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
713   CHECK(shape().IsArray());
714   CHECK_EQ(shape().rank(), 1);
715   CHECK_EQ(element_count(), values.bits());
716   CHECK_EQ(shape().element_type(), PRED);
717   for (int64_t i = 0; i < static_cast<int64>(values.bits()); ++i) {
718     Set({i}, values.get(i));
719   }
720 }
721 
Relayout(const Layout & new_layout,const ShapeIndex & shape_index) const722 Literal LiteralBase::Relayout(const Layout& new_layout,
723                               const ShapeIndex& shape_index) const {
724   // Create new shape with 'new_layout' set at the given shape index.
725   Shape new_shape = shape();
726   Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
727   TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
728   *subshape->mutable_layout() = new_layout;
729   Literal result(new_shape);
730   TF_CHECK_OK(result.CopyFrom(*this));
731   return result;
732 }
733 
Relayout(const Shape & shape_with_layout) const734 Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
735   CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
736       << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
737       << " not compatible with literal shape "
738       << ShapeUtil::HumanString(shape());
739   Literal result = CreateFromShape(shape_with_layout);
740   ShapeUtil::ForEachSubshape(
741       result.shape(),
742       [this, &result](const Shape& subshape, const ShapeIndex& index) {
743         if (subshape.IsArray()) {
744           TF_CHECK_OK(result.CopyFrom(*this,
745                                       /*dest_shape_index=*/index,
746                                       /*src_shape_index=*/index));
747         }
748       });
749   return result;
750 }
751 
// Returns a copy of this literal under the dynamic 'bounded_shape': each array
// dimension's dynamic size is set to this literal's current dimension size,
// then the data is copied with only_dynamic_bound semantics.
Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
  CHECK(bounded_shape.is_dynamic());
  Literal result(bounded_shape);
  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape.rank(); ++i) {
          // NOTE(review): SetDynamicSize is called without the subshape's
          // ShapeIndex, so this appears to assume a non-tuple literal --
          // confirm behavior for nested tuple shapes.
          result.SetDynamicSize(i, subshape.dimensions(i));
        }
      });
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));

  return result;
}
768 
// Returns a static copy of this literal: every dynamic dimension is frozen to
// its current dynamic size and marked non-dynamic before the data is copied.
Literal LiteralBase::ToStatic() const {
  // Build a static shape whose dimension sizes equal this literal's current
  // dynamic sizes.
  Shape new_shape = shape();
  ShapeUtil::ForEachMutableSubshape(
      &new_shape, [this](Shape* subshape, const ShapeIndex& index) {
        if (!subshape->IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape->rank(); ++i) {
          subshape->set_dynamic_dimension(i, false);
          subshape->set_dimensions(i, GetDynamicSize(i, index));
        }
      });
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
  return result;
}
786 
Broadcast(const Shape & result_shape,absl::Span<const int64> dimensions) const787 StatusOr<Literal> LiteralBase::Broadcast(
788     const Shape& result_shape, absl::Span<const int64> dimensions) const {
789   if (!shape().IsArray()) {
790     return InvalidArgument("Broadcast only supports arrays.");
791   }
792 
793   for (int64_t i = 0, end = dimensions.size(); i < end; i++) {
794     TF_RET_CHECK(shape().dimensions(i) ==
795                  result_shape.dimensions(dimensions[i]));
796   }
797 
798   Literal result(result_shape);
799 
800   // scratch_source_index is temporary storage space for the computed index into
801   // the input literal.  We put it here to avoid allocating an std::vector in
802   // every iteration of ShapeUtil::ForEachIndex.
803   std::vector<int64> scratch_source_index(shape().dimensions_size());
804 
805   char* dest_data = static_cast<char*>(result.untyped_data());
806   const char* source_data = static_cast<const char*>(untyped_data());
807   const int64_t primitive_size =
808       ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
809 
810   for (int64_t i = 0; i < dimensions.size(); ++i) {
811     int64_t dynamic_size = GetDynamicSize(i);
812     result.SetDynamicSize(dimensions[i], dynamic_size);
813   }
814 
815   ShapeUtil::ForEachIndex(
816       result_shape, [&](absl::Span<const int64> output_index) {
817         for (int64_t i = 0, end = dimensions.size(); i < end; ++i) {
818           scratch_source_index[i] = output_index[dimensions[i]];
819         }
820         int64_t dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
821             result_shape, output_index);
822         int64_t source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
823             shape(), scratch_source_index);
824         memcpy(dest_data + primitive_size * dest_index,
825                source_data + primitive_size * source_index, primitive_size);
826         return true;
827       });
828 
829   return std::move(result);
830 }
831 
// Returns this literal's data reinterpreted under 'dimensions'. Requires a
// static array literal and an unchanged element count. The data is first
// brought into default (dim0-major) layout so the flat element sequence can
// be reused as-is under the new shape.
StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  if (shape().is_dynamic()) {
    return Unimplemented("Dynamic reshape is not implemented.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64_t elements_before = ShapeUtil::ElementsIn(shape());
  int64_t elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
862 
// Returns a transposed copy of this array literal under 'permutation'.
// Implemented as a shape/layout permutation plus a single flat memcpy rather
// than per-element copies; dynamic sizes are remapped through the inverse
// permutation.
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
  // most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  for (int64_t i = 0; i < shape().rank(); i++) {
    new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
  }
  // The affine layout guarantees both buffers have identical byte size and
  // identical flat element order, so a raw copy suffices.
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}
901 
// Copies the window starting at 'start_indices' with extents given by
// 'result_shape' out of this literal, element by element, and rebases any
// dynamic dimension sizes into the slice's coordinate system.
template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  CHECK(result_literal
            .Populate<NativeT>([&](absl::Span<const int64> indices) {
              // Source index = destination index offset by the slice start.
              for (int64_t i = 0; i < result_shape.rank(); ++i) {
                new_indices[i] = indices[i] + start_indices[i];
              }
              return Get<NativeT>(new_indices);
            })
            .ok());
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    if (shape().is_dynamic_dimension(dnum)) {
      // Dynamic size visible inside the slice: whatever remains after the
      // start offset, clamped to the slice's dimension bound.
      int64_t dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
      CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
      dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
      result_literal.SetDynamicSize(dnum, dynamic_size);
    }
  }
  return result_literal;
}
925 
// Returns the sub-array [start_indices, limit_indices) of this array literal,
// preserving the source minor-to-major layout and dynamic-dimension metadata.
// Dispatches to SliceInternal<T> on the element type; LOG(FATAL)s on element
// types without a case below.
Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
                           absl::Span<const int64> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64_t dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}
979 
Clone() const980 Literal LiteralBase::Clone() const {
981   Literal result(shape());
982   TF_CHECK_OK(result.CopyFrom(*this));
983   return result;
984 }
985 
// Renders the element at 'multi_index' within the subshape at 'shape_index'
// as a string: "true"/"false" for PRED, decimal for integral types,
// round-trippable text for floating-point types, and "(re, im)" pairs for
// complex types. LOG(FATAL)s on unsupported element types.
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return RoundTripFpToString(Get<half>(multi_index, shape_index));
    case F32:
      return RoundTripFpToString(Get<float>(multi_index, shape_index));
    case BF16:
      return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
    case F64:
      return RoundTripFpToString(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}
1031 
// Reads the element at 'multi_index' as an int64 for PRED and all integral
// element types; returns nullopt for any other type.
// NOTE(review): the U64 case implicitly narrows values above int64 max --
// confirm callers never rely on such values.
absl::optional<int64> LiteralBase::GetIntegralAsS64(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case S8:
      return Get<int8>(multi_index);
    case U8:
      return Get<uint8>(multi_index);
    case S16:
      return Get<int16>(multi_index);
    case U16:
      return Get<uint16>(multi_index);
    case S32:
      return Get<int32>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    default:
      return absl::nullopt;
  }
}
1058 
GetAsDouble(absl::Span<const int64> multi_index) const1059 absl::optional<double> LiteralBase::GetAsDouble(
1060     absl::Span<const int64> multi_index) const {
1061   CHECK(LayoutUtil::IsDenseArray(shape()));
1062   switch (shape().element_type()) {
1063     case F16:
1064       return static_cast<double>(Get<half>(multi_index));
1065     case F32:
1066       return static_cast<double>(Get<float>(multi_index));
1067     case F64:
1068       return Get<double>(multi_index);
1069     case BF16:
1070       return static_cast<double>(Get<bfloat16>(multi_index));
1071     default:
1072       return absl::nullopt;
1073   }
1074 }
1075 
// Reads the element at 'multi_index' as a complex128. Real floating-point
// types become (value, 0); complex types are converted directly. Of the
// integral types only S8 is handled; everything else yields nullopt.
absl::optional<complex128> LiteralBase::GetAsComplex128(
    absl::Span<const int64> multi_index) const {
  switch (shape().element_type()) {
    case BF16:
      return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
    case F16:
      return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
    case F32:
      return {{Get<float>(multi_index), 0}};
    case F64:
      return {{Get<double>(multi_index), 0}};
    case C64:
      return {Get<complex64>(multi_index)};
    case C128:
      return {Get<complex128>(multi_index)};
    case S8:
      return {Get<int8>(multi_index)};
    default:
      return absl::nullopt;
  }
}
1097 
// Hashes this literal by combining ShapeUtil::Hash(shape()) with Hash64 over
// the raw backing bytes of every dense array subshape.
size_t LiteralBase::Hash() const {
  using tensorflow::Hash64;
  using tensorflow::Hash64Combine;

  size_t hash_value = ShapeUtil::Hash(shape());

  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }

        CHECK(LayoutUtil::IsDense(subshape.layout()));
        // Fold this piece's bytes into the running hash.
        hash_value = Hash64Combine(
            hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
                               size_bytes(index)));
      });

  return hash_value;
}
1118 
// Casts 'value' to the literal's element type and stores it at 'multi_index'.
// Supports PRED, U8, U32, U64, S32 and S64; any other element type (note:
// including S8, S16 and U16) yields FailedPrecondition.
Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
                                            int64_t value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      Set<bool>(multi_index, value);
      break;
    case U8:
      Set<uint8>(multi_index, value);
      break;
    case S32:
      Set<int32>(multi_index, value);
      break;
    case S64:
      Set<int64>(multi_index, value);
      break;
    case U32:
      Set<uint32>(multi_index, value);
      break;
    case U64:
      Set<uint64>(multi_index, value);
      break;
    default:
      return FailedPrecondition("Array element type is not integral: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}
1147 
// Casts 'value' to the literal's floating-point element type (F16, F32, F64
// or BF16) and stores it at 'multi_index'; other element types yield
// FailedPrecondition.
Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
                                         double value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      Set<half>(multi_index, Eigen::half(value));
      break;
    case F32:
      Set<float>(multi_index, value);
      break;
    case F64:
      Set<double>(multi_index, value);
      break;
    case BF16:
      Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
      break;
    default:
      return FailedPrecondition("Array element type is not floating: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}
1170 
1171 namespace {
1172 
ShapeToString(bool print_layout,const Shape & shape)1173 string ShapeToString(bool print_layout, const Shape& shape) {
1174   return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
1175                       : ShapeUtil::HumanString(shape);
1176 }
1177 
1178 void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
1179                     bool print_shape, bool print_layout,
1180                     std::vector<string>* pieces);
1181 
// Renders the tuple at 'shape_index' as "(\n<elem>,\n<elem>\n)", recursing
// into ToStringHelper for each tuple element and appending the pieces to
// 'pieces'.
void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<string> tuple_pieces;
  for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
    // Extend the shape index by this element's position within the tuple.
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}
1199 
// Renders a dense array subshape as nested braces (e.g. "{ { 1, 2 }, ... }"),
// appending pieces to 'pieces'. When 'print_shape' is set, the shape -- plus
// parenthesized dynamic sizes, if the subshape is dynamic -- prefixes the
// data. PRED elements in rank>0 arrays print as 0/1 for compactness.
void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64_t rank = subshape.rank();

  std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
      to_string_recursive = [&](absl::Span<const int64> dimensions,
                                std::vector<int64>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        auto brace_to_string = [&](string brace) -> string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat("  ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            const int64_t accum_indices_size = accum_indices->size();
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices_size < rank) {
              // Annotate high-rank openings with the enclosing index, e.g.
              // "{ /*i0=3*/".
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            std::vector<int64> cloned_indices(*accum_indices);
            cloned_indices.push_back(i);
            to_string_recursive(dimensions.subspan(1), &cloned_indices);
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    if (subshape.is_dynamic()) {
      // Dynamic shapes additionally print their current sizes, e.g. "(3,4)".
      pieces->push_back("(");
      for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
        pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
        if (i < subshape.dimensions_size() - 1) {
          pieces->push_back(",");
        }
      }
      pieces->push_back(")");
    }
    pieces->push_back(" ");
  }
  // Iterate using the dynamic sizes so only live elements are printed.
  std::vector<int64> indices = {};
  std::vector<int64> dimensions;
  dimensions.reserve(subshape.rank());
  for (int64_t i = 0; i < subshape.rank(); ++i) {
    dimensions.push_back(literal.GetDynamicSize(i, shape_index));
  }
  to_string_recursive(dimensions, &indices);
}
1286 
// Dispatches rendering of the subshape at 'shape_index': tuples recurse via
// TupleToStringHelper, tokens print as "token", and dense arrays go through
// DenseArrayToStringHelper.
void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                             pieces);
  }
}
1304 
1305 }  // namespace
1306 
ToString() const1307 string LiteralBase::ToString() const {
1308   std::vector<string> pieces;
1309   CHECK(LayoutUtil::HasLayout(this->shape()));
1310   ToStringHelper(*this, {}, /*print_shape=*/true,
1311                  /*print_layout=*/false, &pieces);
1312   return absl::StrJoin(pieces, "");
1313 }
1314 
// Like ToString(), but post-processed by CompactOneline() into a single line.
string LiteralBase::ToStringOneline() const {
  return CompactOneline(ToString());
}
1318 
ToStringWithoutShape() const1319 string LiteralBase::ToStringWithoutShape() const {
1320   std::vector<string> pieces;
1321   CHECK(LayoutUtil::HasLayout(this->shape()));
1322   ToStringHelper(*this, {}, /*print_shape=*/false,
1323                  /*print_layout=*/false, &pieces);
1324   return absl::StrJoin(pieces, "");
1325 }
1326 
// Like ToStringWithoutShape(), compacted onto a single line.
string LiteralBase::ToStringWithoutShapeOneline() const {
  return CompactOneline(ToStringWithoutShape());
}
1330 
ToStringWithLayout() const1331 string LiteralBase::ToStringWithLayout() const {
1332   std::vector<string> pieces;
1333   CHECK(LayoutUtil::HasLayout(this->shape()));
1334   ToStringHelper(*this, {}, /*print_shape=*/true,
1335                  /*print_layout=*/true, &pieces);
1336   return absl::StrJoin(pieces, "");
1337 }
1338 
// Like ToStringWithLayout(), compacted onto a single line.
string LiteralBase::ToStringWithLayoutOneline() const {
  return CompactOneline(ToStringWithLayout());
}
1342 
// Invokes 'per_cell' once per element of this array literal, passing the
// element's multi-dimensional index and its GetAsString rendering. Indices
// start at linear index 0 and advance via IndexUtil::BumpIndices;
// zero-element arrays produce no calls.
void LiteralBase::EachCellAsString(
    const std::function<void(absl::Span<const int64> indices,
                             const string& value)>& per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
      shape(), /*linear_index=*/0);
  do {
    per_cell(indices, GetAsString(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
1355 
1356 namespace {
1357 template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
ConvertBetweenNativeTypesWithConverter(const LiteralBase & src_literal,const ConverterType & converter)1358 Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
1359                                                const ConverterType& converter) {
1360   CHECK(src_literal.shape().IsArray());
1361   Literal result_literal(ShapeUtil::ChangeElementType(
1362       src_literal.shape(),
1363       primitive_util::NativeToPrimitiveType<NativeDestT>()));
1364   auto src_data = src_literal.data<NativeSrcT>();
1365   auto dest_data = result_literal.template data<NativeDestT>();
1366   int64_t num_elements = src_literal.element_count();
1367 
1368   for (int64_t i = 0; i < num_elements; ++i) {
1369     dest_data[i] = converter(src_data[i]);
1370   }
1371   return result_literal;
1372 }
1373 
// Overload for Eigen::half -> complex conversions: the half is widened to the
// complex type's real component type first, and the imaginary part is zero.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_same<NativeSrcT, Eigen::half>::value &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1386 
// Overload for floating-point -> integral conversions: saturates at the
// destination type's limits and maps NaN to zero, avoiding the undefined
// behavior of a plain static_cast; bool keeps the standard zero/non-zero
// conversion.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_floating_point<NativeSrcT>::value &&
                            std::is_integral<NativeDestT>::value,
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    // C++ [conv.bool]p1:
    //   A prvalue of arithmetic [...] type can be converted to a prvalue of
    //   type bool. A zero value [...] is converted to false; any other value is
    //   converted to true.
    // C++ [conv.fpint]p1:
    //   [...] The behavior is undefined if the truncated value cannot be
    //   represented in the destination type.
    //
    // Using static_cast to convert a float to an integral type other than bool
    // may be undefined if the value's magnitude is too large or it is a NaN.
    // Let's choose saturating arithmetic as it captures the spirit of infinity
    // and arbitrarily map NaN to zero.
    if (!std::is_same<NativeDestT, bool>::value) {
      if (src != src) {
        // NaN compares unequal to itself; map it to zero.
        return NativeDestT{0};
      }
      if (src >= std::numeric_limits<NativeDestT>::max()) {
        return std::numeric_limits<NativeDestT>::max();
      }
      if (src <= std::numeric_limits<NativeDestT>::lowest()) {
        return std::numeric_limits<NativeDestT>::lowest();
      }
    }
    return static_cast<NativeDestT>(src);
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1421 
// Fallback overload for all remaining type pairs: a plain static_cast per
// element.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<!(std::is_floating_point<NativeSrcT>::value &&
                          std::is_integral<NativeDestT>::value) &&
                            !(std::is_same<NativeSrcT, Eigen::half>::value &&
                              (std::is_same<NativeDestT, complex64>::value ||
                               std::is_same<NativeDestT, complex128>::value)),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1434 
1435 template <typename NativeSrcT, typename NativeDestT>
1436 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
1437                          !std::is_same<NativeDestT, Eigen::half>::value),
1438                         Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1439 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1440   auto converter = [](NativeSrcT src) {
1441     return absl::bit_cast<NativeDestT>(GetRawValue(src));
1442   };
1443   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1444       src_literal, converter);
1445 }
1446 
// Bitcast overload for a destination of Eigen::half: enabled for any source
// native type whose size matches Eigen::half (16 bits).
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short first.
  auto converter = [](NativeSrcT src) {
    return Eigen::numext::bit_cast<Eigen::half>(
        absl::bit_cast<uint16>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}
1461 
// This template specialization is here to make the compiler happy. bit_cast has
// a static check that the types are the same size. This specialization should
// never be used because the source and destination types are checked for
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // LOG(FATAL) aborts the process, so no return statement is needed here.
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
1472 
1473 template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
ConvertIfTypesMatch(const LiteralBase & src_literal,bool bitcast)1474 Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
1475   CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
1476   if (bitcast) {
1477     return BitcastBetweenNativeTypes<
1478         typename primitive_util::PrimitiveTypeToNative<
1479             primitive_src_type>::type,
1480         typename primitive_util::PrimitiveTypeToNative<
1481             primitive_dest_type>::type>(src_literal);
1482   } else {
1483     return ConvertBetweenNativeTypes<
1484         typename primitive_util::PrimitiveTypeToNative<
1485             primitive_src_type>::type,
1486         typename primitive_util::PrimitiveTypeToNative<
1487             primitive_dest_type>::type>(src_literal);
1488   }
1489 }
1490 
1491 template <PrimitiveType primitive_src_type>
ConvertIfDestTypeMatches(const LiteralBase & src_literal,PrimitiveType primitive_dest_type,bool bitcast)1492 StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
1493                                            PrimitiveType primitive_dest_type,
1494                                            bool bitcast) {
1495   switch (primitive_dest_type) {
1496 #define CONVERT_IF_TYPES_MATCH(type)                                    \
1497   case (type):                                                          \
1498     return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
1499                                                            bitcast);
1500     CONVERT_IF_TYPES_MATCH(PRED)
1501     CONVERT_IF_TYPES_MATCH(S8)
1502     CONVERT_IF_TYPES_MATCH(S16)
1503     CONVERT_IF_TYPES_MATCH(S32)
1504     CONVERT_IF_TYPES_MATCH(S64)
1505     CONVERT_IF_TYPES_MATCH(U8)
1506     CONVERT_IF_TYPES_MATCH(U16)
1507     CONVERT_IF_TYPES_MATCH(U32)
1508     CONVERT_IF_TYPES_MATCH(U64)
1509     CONVERT_IF_TYPES_MATCH(F16)
1510     CONVERT_IF_TYPES_MATCH(F32)
1511     CONVERT_IF_TYPES_MATCH(F64)
1512     CONVERT_IF_TYPES_MATCH(BF16)
1513 #undef CONVERT_IF_TYPES_MATCH
1514     case C64:
1515       if (bitcast) {
1516         break;
1517       }
1518       return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
1519     case C128:
1520       if (bitcast) {
1521         break;
1522       }
1523       return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
1524     // Other types are not yet supported.
1525     default:
1526       break;
1527   }
1528   return Unimplemented("Converting from type %s to type %s is not implemented.",
1529                        PrimitiveType_Name(src_literal.shape().element_type()),
1530                        PrimitiveType_Name(primitive_dest_type));
1531 }
1532 
// Runtime dispatch on the SOURCE element type; pairs with
// ConvertIfDestTypeMatches, which dispatches on the destination type. A
// same-type conversion is a no-op clone. Only array literals are accepted.
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
      // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}
1567 
1568 }  // namespace
1569 
// Returns a copy of this literal with each element value-converted to
// `primitive_dest_type` (numeric conversion, not a reinterpretation of bits).
StatusOr<Literal> LiteralBase::Convert(
    PrimitiveType primitive_dest_type) const {
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
1574 
BitcastConvert(PrimitiveType primitive_dest_type) const1575 StatusOr<Literal> LiteralBase::BitcastConvert(
1576     PrimitiveType primitive_dest_type) const {
1577   if (primitive_util::BitWidth(shape().element_type()) !=
1578       primitive_util::BitWidth(primitive_dest_type)) {
1579     return InvalidArgument(
1580         "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
1581         "%d",
1582         PrimitiveType_Name(shape().element_type()),
1583         PrimitiveType_Name(primitive_dest_type),
1584         primitive_util::BitWidth(shape().element_type()),
1585         primitive_util::BitWidth(primitive_dest_type));
1586   }
1587   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
1588 }
1589 
ConvertToShape(const Shape & dest_shape) const1590 StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
1591   if (!dest_shape.IsTuple()) {
1592     return Convert(dest_shape.element_type());
1593   }
1594   std::vector<Literal> elements;
1595   for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
1596     auto element = LiteralSlice(*this, {i});
1597     TF_ASSIGN_OR_RETURN(
1598         auto new_element,
1599         element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
1600     elements.push_back(std::move(new_element));
1601   }
1602   return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
1603 }
1604 
MoveIntoTuple(absl::Span<Literal> elements)1605 /* static */ Literal MutableLiteralBase::MoveIntoTuple(
1606     absl::Span<Literal> elements) {
1607   std::vector<Shape> element_shapes;
1608   for (const Literal& element : elements) {
1609     element_shapes.push_back(element.shape());
1610   }
1611   Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
1612                   /*allocate_arrays=*/false);
1613   for (int i = 0, end = elements.size(); i < end; ++i) {
1614     TF_CHECK_OK(
1615         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
1616   }
1617   return literal;
1618 }
1619 
1620 template <typename NativeT>
CopyElementsWithDynamicBound(const LiteralBase::Piece & src)1621 void LiteralBase::Piece::CopyElementsWithDynamicBound(
1622     const LiteralBase::Piece& src) {
1623   auto dest_shape = subshape();
1624   auto src_shape = src.subshape();
1625 
1626   // At least one shape has to be static as bound.
1627   CHECK(dest_shape.is_static() || src_shape.is_static());
1628   auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
1629   if (ShapeUtil::IsZeroElementArray(dest_shape)) {
1630     return;
1631   }
1632   std::vector<int64> index(dest_shape.rank());
1633   do {
1634     bool out_of_bound = false;
1635     for (int64_t i = 0; i < index.size(); ++i) {
1636       // Do not copy elements beyond dynamic bound.
1637       if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
1638         out_of_bound = true;
1639       }
1640     }
1641     if (out_of_bound) {
1642       continue;
1643     }
1644     data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
1645                                                                   index)] =
1646         src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
1647             src_shape, index)];
1648   } while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
1649 }
1650 
// Recursively compares this piece's elements against `other`. `multi_index`
// holds the dimension indices fixed so far; once it is fully populated
// (size == rank) the two scalars at that index are compared directly.
// Iteration bounds come from GetDynamicSize, so only elements within this
// piece's dynamic bound are visited.
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64_t i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    // Restore the index prefix before advancing to the next coordinate.
    multi_index->pop_back();
  }
  return true;
}
1666 
EqualDynamicSize(const LiteralBase::Piece & other) const1667 bool LiteralBase::Piece::EqualDynamicSize(
1668     const LiteralBase::Piece& other) const {
1669   DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
1670   if (subshape().is_static()) {
1671     return true;
1672   }
1673 
1674   for (int64_t i = 0; i < subshape().rank(); ++i) {
1675     if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
1676       return false;
1677     }
1678   }
1679   return true;
1680 }
1681 
// Compares this piece's elements to `other`'s. Fast path: for static,
// layout-identical dense arrays the backing buffers are compared with memcmp.
// Otherwise falls back to an element-wise recursive comparison dispatched on
// the element type. CHECK-fails on element types without a comparison.
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
  if (subshape().is_static() &&
      ShapeUtil::Equal(subshape(), other.subshape()) &&
      LayoutUtil::IsDenseArray(subshape())) {
    CHECK_EQ(size_bytes(), other.size_bytes());
    return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
  }

  std::vector<int64> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case S8:
      return EqualElementsInternal<int8>(other, &multi_index);
    case S16:
      return EqualElementsInternal<int16>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8>(other, &multi_index);
    case U16:
      return EqualElementsInternal<uint16>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    case C128:
      return EqualElementsInternal<complex128>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}
1727 
// Deep equality: two literals are equal when their tuple structure matches and
// every corresponding array piece has the same element type, rank, dynamic
// sizes, and element values.
bool LiteralBase::operator==(const LiteralBase& other) const {
  // Checking the structure of tuple literals. Checks for dense arrays are
  // performed below.
  if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
    return false;
  }

  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        const Piece& other_piece = other.piece(index);
        const Shape& subshape = piece.subshape();
        const Shape& other_subshape = other_piece.subshape();
        if (subshape.element_type() != other_subshape.element_type()) {
          return false;
        }
        // Non-array pieces (tuple nodes, tokens) carry no element data.
        if (!piece.subshape().IsArray()) {
          return true;
        }
        if (subshape.rank() != other_subshape.rank()) {
          return false;
        }

        // Dynamic sizes must agree before element values are compared.
        for (int64_t i = 0; i < subshape.rank(); ++i) {
          if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
            return false;
          }
        }

        if (!piece.EqualElements(other_piece)) {
          return false;
        }
        return true;
      });
}
1762 
1763 namespace {
1764 
1765 template <typename NativeT>
AllElementsEqualValue(absl::Span<const NativeT> data,NativeT value)1766 static bool AllElementsEqualValue(absl::Span<const NativeT> data,
1767                                   NativeT value) {
1768   for (int64_t i = 0; i < data.size(); ++i) {
1769     if (data[i] != value) {
1770       return false;
1771     }
1772   }
1773   return true;
1774 }
1775 
1776 }  // namespace
1777 
// Returns true if every element of every array piece equals `value`.
// Negative values can never match unsigned element types, and PRED only
// matches 0 or 1. Unsupported element types (complex, etc.) yield false.
//
// NOTE(review): the switch below reads shape().element_type() — the ROOT
// shape's element type — rather than piece.subshape().element_type() as
// IsAllFirst does; for tuple literals the root type is TUPLE, so the default
// branch returns false for every piece. Confirm this is intended.
bool LiteralBase::IsAll(int8_t value) const {
  return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
                                                  const Piece& piece) {
    if (!piece.subshape().IsArray()) {
      return true;
    }

    auto piece_is_all = [&]() {
      switch (shape().element_type()) {
        // A negative `value` cannot appear in an unsigned literal.
        case U8:
          if (value >= 0) {
            return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
          }
          return false;
        case U16:
          if (value >= 0) {
            return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
          }
          return false;
        case U32:
          if (value >= 0) {
            return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
          }
          return false;
        case U64:
          if (value >= 0) {
            return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
          }
          return false;
        case S8:
          return AllElementsEqualValue<int8>(piece.data<int8>(), value);
        case S16:
          return AllElementsEqualValue<int16>(piece.data<int16>(), value);
        case S32:
          return AllElementsEqualValue<int32>(piece.data<int32>(), value);
        case S64:
          return AllElementsEqualValue<int64>(piece.data<int64>(), value);
        case F32:
          return AllElementsEqualValue<float>(piece.data<float>(), value);
        case F64:
          return AllElementsEqualValue<double>(piece.data<double>(), value);
        case F16:
          return AllElementsEqualValue<half>(piece.data<half>(),
                                             static_cast<half>(value));
        case BF16:
          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
                                                 static_cast<bfloat16>(value));
        case PRED:
          if (value == 0) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), false);
          }
          if (value == 1) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), true);
          }
          return false;
        default:
          return false;
      }
      // Unreachable: every switch branch above returns.
      return false;
    };

    if (!piece_is_all()) {
      return false;
    }
    return true;
  });
}
1845 
// Returns true if every element of every array piece equals `value`, for
// floating-point element types only; any other element type yields false.
//
// NOTE(review): like IsAll, this switches on shape().element_type() (the root
// shape) rather than piece.subshape().element_type() — confirm intended for
// tuple literals.
bool LiteralBase::IsAllFloat(float value) const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        switch (shape().element_type()) {
          case F32:
            return AllElementsEqualValue<float>(piece.data<float>(), value);
          case F64:
            return AllElementsEqualValue<double>(piece.data<double>(), value);
          case F16:
            return AllElementsEqualValue<half>(piece.data<half>(),
                                               static_cast<half>(value));
          case BF16:
            return AllElementsEqualValue<bfloat16>(
                piece.data<bfloat16>(), static_cast<bfloat16>(value));
          default:
            return false;
        }
      });
}
1869 
// Returns true if every element equals `value`, for complex element types
// only. Non-complex element types (including TUPLE) hit the default branch
// and return false before any element data is touched.
bool LiteralBase::IsAllComplex(complex64 value) const {
  switch (shape().element_type()) {
    case C64:
      return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
                                              value);
    case C128:
      return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
                                               value);
    default:
      return false;
  }
}
1882 
// Returns true if, in every array piece, all elements equal that piece's
// first element. Zero-element arrays have no first element and yield false;
// element types not listed below also yield false.
bool LiteralBase::IsAllFirst() const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // Empty shapes are not all the first element since there is no first
        // element.
        if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
          return false;
        }
        auto piece_is_all = [&]() {
          switch (piece.subshape().element_type()) {
            case PRED: {
              auto data = piece.data<bool>();
              return AllElementsEqualValue<bool>(data, data[0]);
            }
            // 8 bit types
            case S8: {
              auto data = piece.data<int8>();
              return AllElementsEqualValue<int8>(data, data[0]);
            }
            case U8: {
              auto data = piece.data<uint8>();
              return AllElementsEqualValue<uint8>(data, data[0]);
            }
            // 16 bit types
            case BF16: {
              auto data = piece.data<bfloat16>();
              return AllElementsEqualValue<bfloat16>(data, data[0]);
            }
            case F16: {
              auto data = piece.data<half>();
              return AllElementsEqualValue<half>(data, data[0]);
            }
            case S16: {
              auto data = piece.data<int16>();
              return AllElementsEqualValue<int16>(data, data[0]);
            }
            case U16: {
              auto data = piece.data<uint16>();
              return AllElementsEqualValue<uint16>(data, data[0]);
            }
            // 32 bit types
            case F32: {
              auto data = piece.data<float>();
              return AllElementsEqualValue<float>(data, data[0]);
            }
            case U32: {
              auto data = piece.data<uint32>();
              return AllElementsEqualValue<uint32>(data, data[0]);
            }
            case S32: {
              auto data = piece.data<int32>();
              return AllElementsEqualValue<int32>(data, data[0]);
            }
            // 64 bit types
            case C64: {
              auto data = piece.data<complex64>();
              return AllElementsEqualValue<complex64>(data, data[0]);
            }
            case F64: {
              auto data = piece.data<double>();
              return AllElementsEqualValue<double>(data, data[0]);
            }
            case S64: {
              auto data = piece.data<int64>();
              return AllElementsEqualValue<int64>(data, data[0]);
            }
            case U64: {
              auto data = piece.data<uint64>();
              return AllElementsEqualValue<uint64>(data, data[0]);
            }

            case C128: {
              auto data = piece.data<complex128>();
              return AllElementsEqualValue<complex128>(data, data[0]);
            }
            default:
              return false;
          }
        };

        if (!piece_is_all()) {
          return false;
        }
        return true;
      });
}
1973 
// Returns true if this literal is a rank-1 array whose element at index i
// equals i (an iota), for numeric element types. PRED and non-array types are
// never considered iota.
bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  // Checks element `idx` against the value idx, converted into the literal's
  // element type.
  auto is_iota_at_idx = [&](const int64_t idx) {
    switch (shape().element_type()) {
      case U8:
        return static_cast<int64>(Get<uint8>({idx})) == idx;
      case U16:
        return static_cast<int64>(Get<uint16>({idx})) == idx;
      case U32:
        return static_cast<int64>(Get<uint32>({idx})) == idx;
      case U64:
        return static_cast<int64>(Get<uint64>({idx})) == idx;
      case S8:
        return Get<int8>({idx}) == idx;
      case S16:
        return Get<int16>({idx}) == idx;
      case S32:
        return Get<int32>({idx}) == idx;
      case S64:
        return Get<int64>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // pred, token, opaque, tuple, etc. are all not iota.
      default:
        return false;
    }
  };

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  for (int64_t idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}
2028 
// Returns a stride if the literal is a strided iota, i.e., iota multiplied by a
// stride. Only applicable for integer iotas. Returns absl::nullopt if the
// literal is not a strided iota.
absl::optional<int64> LiteralBase::IsR1StridedIota() const {
  if (!shape().IsArray() || shape().rank() != 1) {
    return absl::nullopt;
  }

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  const PrimitiveType type = shape().element_type();
  // Zero- and one-element arrays have no observable stride.
  if (elements <= 1 || !primitive_util::IsIntegralType(type)) {
    return absl::nullopt;
  }

  // Reads element `idx`, widened to int64. CHECK-fails on non-integral types,
  // which were already excluded above.
  auto get_element_at = [&](const int64_t idx) -> int64 {
    switch (type) {
      case U8:
        return static_cast<int64>(Get<uint8>({idx}));
      case U16:
        return static_cast<int64>(Get<uint16>({idx}));
      case U32:
        return static_cast<int64>(Get<uint32>({idx}));
      case U64:
        return static_cast<int64>(Get<uint64>({idx}));
      case S8:
        return Get<int8>({idx});
      case S16:
        return Get<int16>({idx});
      case S32:
        return Get<int32>({idx});
      case S64:
        return Get<int64>({idx});
      default:
        CHECK(0);
        return 0;
    }
  };

  // Infer the stride as the second element (since first element is supposed
  // to be zero).
  int64_t stride = get_element_at(1);
  if (stride == 0) {
    return absl::nullopt;
  }

  // The loop starts at 0, which also verifies that element 0 equals 0.
  for (int64_t idx = 0; idx < elements; ++idx) {
    if (get_element_at(idx) != idx * stride) {
      return absl::nullopt;
    }
  }

  return stride;
}
2082 
// Returns true if the element at `indices` is zero (false for PRED).
// CHECK-fails if this literal is not an array or the element type is not
// handled below.
bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8>(indices) == 0;
    case U16:
      return Get<uint16>(indices) == 0;
    case U32:
      return Get<uint32>(indices) == 0;
    case U64:
      return Get<uint64>(indices) == 0;
    case S8:
      return Get<int8>(indices) == 0;
    case S16:
      return Get<int16>(indices) == 0;
    case S32:
      return Get<int32>(indices) == 0;
    case S64:
      return Get<int64>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "Input literal must be an array.";
  }
}
2120 
2121 namespace {
2122 
// Replaces the contents of `dest` (a proto repeated field) with the elements
// of `src`, via the repeated field's iterator-range constructor.
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}
2128 
2129 }  // namespace
2130 
// Serializes this piece (shape plus element data) into `proto`. 8-bit and
// 16-bit element types are stored as raw byte strings; 16-bit data is
// normalized to little-endian byte order on big-endian hosts. Complex values
// are stored as interleaved real/imag pairs. Tuples and tokens carry only
// their shape. CHECK-fails on unhandled element types.
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case U16:
      // 16-bit data is serialized as raw bytes; normalize to little-endian.
      *proto->mutable_u16s() = string(
          reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      // Complex values are flattened as [real0, imag0, real1, imag1, ...].
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}
2213 
untyped_data() const2214 const void* LiteralBase::Piece::untyped_data() const {
2215   CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
2216   return buffer();
2217 }
2218 
untyped_data()2219 void* LiteralBase::Piece::untyped_data() {
2220   CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
2221   return buffer();
2222 }
2223 
2224 namespace {
2225 
2226 template <typename RepeatedFieldT, typename NativeT>
CopyFromRepeatedField(absl::Span<NativeT> dest,const RepeatedFieldT & src)2227 Status CopyFromRepeatedField(absl::Span<NativeT> dest,
2228                              const RepeatedFieldT& src) {
2229   if (dest.size() != src.size()) {
2230     return InvalidArgument(
2231         "Expected %lu elements in LiteralProto repeated field, has %d",
2232         dest.size(), src.size());
2233   }
2234   std::copy(src.begin(), src.end(), dest.begin());
2235   return Status::OK();
2236 }
2237 
2238 }  // namespace
2239 
// Deserializes this piece's array contents from `proto`, the inverse of
// Piece::WriteToProto.
//
// Preconditions (re-validated here with TF_RET_CHECK): the proto carries a
// shape, that shape has a layout, and it equals this piece's subshape.
//
// Per-type wire formats, mirroring WriteToProto:
//  - PRED and the 32/64-bit integer, float and complex types use typed
//    repeated fields (complex values are interleaved as real, imag pairs).
//  - S8/U8 are stored as raw bytes in a string field.
//  - 16-bit element types (S16/U16/F16/BF16) are stored as packed
//    little-endian bytes in a string field; on big-endian hosts the bytes
//    are swapped in place after the memcpy.
// Returns InvalidArgument for tuple shapes and unsupported element types.
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    // 8-bit types travel as raw bytes in a string field, so they are copied
    // element-wise rather than through CopyFromRepeatedField.
    case S8: {
      auto s8_data = data<int8>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    // 16-bit types: bulk-copy the packed little-endian bytes, then byte-swap
    // in place when the host is big-endian.
    case S16: {
      const string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;

    case BF16: {
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    // Complex values are serialized as interleaved (real, imag) pairs, so the
    // repeated field holds 2x the number of logical elements.
    case C64: {
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64_t i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      const int64_t complex_data_size_doubled = complex_data.size() * 2;
      TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled);
      for (int64_t i = 0, end = complex_data.size(); i < end; ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Is called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return Status::OK();
}
2340 
ToProto() const2341 LiteralProto LiteralBase::ToProto() const {
2342   LiteralProto proto;
2343   root_piece().ForEachSubpiece(
2344       [&](const ShapeIndex& index, const Piece& piece) {
2345         LiteralProto* proto_piece = &proto;
2346         for (int64_t i : index) {
2347           while (proto_piece->tuple_literals_size() <= i) {
2348             proto_piece->add_tuple_literals();
2349           }
2350           proto_piece = proto_piece->mutable_tuple_literals(i);
2351         }
2352         piece.WriteToProto(proto_piece);
2353       });
2354 
2355   return proto;
2356 }
2357 
untyped_data(const ShapeIndex & shape_index) const2358 const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
2359   return piece(shape_index).untyped_data();
2360 }
2361 
untyped_data(const ShapeIndex & shape_index)2362 void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
2363   return piece(shape_index).untyped_data();
2364 }
2365 
size_bytes(const ShapeIndex & shape_index) const2366 int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
2367   return piece(shape_index).size_bytes();
2368 }
2369 
GetR1U8AsString() const2370 string LiteralBase::GetR1U8AsString() const {
2371   CHECK(shape().IsArray());
2372   CHECK_EQ(shape().rank(), 1);
2373   CHECK_EQ(shape().element_type(), U8);
2374   return string(absl::bit_cast<const char*>(data<uint8>().data()),
2375                 ShapeUtil::ElementsIn(shape()));
2376 }
2377 
CopyPieceSubtree(const Shape & shape,Piece * src_piece,Piece * dest_piece)2378 void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
2379                                                Piece* src_piece,
2380                                                Piece* dest_piece) {
2381   DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
2382       << "src_piece has shape: "
2383       << ShapeUtil::HumanString(src_piece->subshape())
2384       << "dest_piece has shape: "
2385       << ShapeUtil::HumanString(dest_piece->subshape());
2386   if (shape.IsTuple()) {
2387     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
2388       const Shape& subshape = shape.tuple_shapes(i);
2389 
2390       auto child_piece = Piece();
2391       child_piece.set_subshape(&subshape);
2392 
2393       CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
2394 
2395       dest_piece->emplace_back(std::move(child_piece));
2396     }
2397   } else if (shape.IsArray()) {
2398     dest_piece->set_buffer(src_piece->buffer());
2399     dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
2400   } else {
2401     // If the shape is neither an array nor tuple, then it must be
2402     // zero-sized. Otherwise, some memory needs to be allocated for it.
2403     CHECK_EQ(dest_piece->size_bytes(), 0);
2404   }
2405 }
2406 
~MutableLiteralBase()2407 MutableLiteralBase::~MutableLiteralBase() {}
2408 
// Copy constructor: produces another mutable view that aliases the same
// underlying buffers as `other`; no element data is copied.
MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& other)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(other.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  CopyPieceSubtree(*shape_, &other.root_piece(), root_piece_);
}
2420 
// Copy assignment: re-points this view at the buffers aliased by `literal`.
// Fixes a leak: the previous implementation overwrote the owning
// `root_piece_` pointer without deleting it (the destructor shows this class
// owns root_piece_), and had no self-assignment guard.
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  if (this != &literal) {
    shape_ = absl::make_unique<Shape>(literal.shape());
    CHECK(LayoutUtil::HasLayout(*shape_));

    // Release the piece tree from any previous assignment/construction
    // before building the new one.
    delete root_piece_;
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());

    CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
  }
  return *this;
}
2433 
MutableBorrowingLiteral(MutableLiteralBase * literal)2434 MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
2435     : MutableLiteralBase() {
2436   shape_ = absl::make_unique<Shape>(literal->shape());
2437   CHECK(LayoutUtil::HasLayout(*shape_));
2438 
2439   root_piece_ = new Piece();
2440   root_piece_->set_subshape(shape_.get());
2441 
2442   CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
2443 }
2444 
MutableBorrowingLiteral(MutableBorrowingLiteral literal,const ShapeIndex & view_root)2445 MutableBorrowingLiteral::MutableBorrowingLiteral(
2446     MutableBorrowingLiteral literal, const ShapeIndex& view_root)
2447     : MutableLiteralBase() {
2448   shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
2449   CHECK(LayoutUtil::HasLayout(*shape_));
2450 
2451   root_piece_ = new Piece();
2452   root_piece_->set_subshape(shape_.get());
2453 
2454   CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
2455 }
2456 
MutableBorrowingLiteral(const char * src_buf_ptr,const Shape & shape)2457 MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
2458                                                  const Shape& shape)
2459     : MutableLiteralBase() {
2460   shape_ = absl::make_unique<Shape>(shape);
2461   CHECK(LayoutUtil::HasLayout(*shape_));
2462   CHECK(!shape_->IsTuple());
2463 
2464   root_piece_ = new Piece();
2465   root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
2466   root_piece_->set_subshape(shape_.get());
2467 }
2468 
MutableBorrowingLiteral(absl::Span<char * > src_buf_ptrs,const Shape & shape)2469 MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
2470                                                  const Shape& shape)
2471     : MutableLiteralBase() {
2472   shape_ = absl::make_unique<Shape>(shape);
2473   if (!shape_->IsTuple()) {
2474     CHECK_EQ(src_buf_ptrs.size(), 1);
2475     root_piece_ = new Piece();
2476     root_piece_->set_buffer(const_cast<char*>(src_buf_ptrs[0]));
2477     root_piece_->set_subshape(shape_.get());
2478   } else {
2479     CHECK(!ShapeUtil::IsNestedTuple(*shape_));
2480     CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
2481     root_piece_ = new Piece();
2482     root_piece_->set_subshape(shape_.get());
2483 
2484     for (int i = 0; i < src_buf_ptrs.size(); ++i) {
2485       Piece child_piece;
2486       const auto& src_shape = shape_->tuple_shapes(i);
2487       CHECK(src_shape.IsArray());
2488       child_piece.set_subshape(&src_shape);
2489       child_piece.set_buffer(src_buf_ptrs[i]);
2490       root_piece_->emplace_back(std::move(child_piece));
2491     }
2492   }
2493 }
2494 
~MutableBorrowingLiteral()2495 MutableBorrowingLiteral::~MutableBorrowingLiteral() {
2496   if (root_piece_ != nullptr) {
2497     delete root_piece_;
2498   }
2499 }
2500 
// Read-only, non-owning view over all of `literal`; stores a raw pointer to
// its root piece, so `literal` must outlive this slice.
LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}
2503 
// Read-only, non-owning view of the sub-literal at `view_root` inside
// `literal`; stores a raw pointer, so `literal` must outlive this slice.
LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
2507 
BuildPieceSubtree(const Shape & shape,Piece * piece)2508 void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
2509   CHECK(shape.IsTuple());
2510   for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
2511     const Shape& subshape = shape.tuple_shapes(i);
2512 
2513     auto child_piece = Piece();
2514     child_piece.set_subshape(&subshape);
2515 
2516     if (subshape.IsTuple()) {
2517       BuildPieceSubtree(subshape, &child_piece);
2518     }
2519 
2520     piece->emplace_back(std::move(child_piece));
2521   }
2522 }
2523 
BorrowingLiteral(const char * src_buf_ptr,const Shape & shape)2524 BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
2525     : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
2526   CHECK(shape_->IsArray());
2527   CHECK(LayoutUtil::HasLayout(*shape_));
2528 
2529   root_piece_ = Piece();
2530   root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
2531   root_piece_.set_subshape(shape_.get());
2532 }
2533 
BorrowingLiteral(absl::Span<const char * const> src_buf_ptrs,const Shape & shape)2534 BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
2535                                    const Shape& shape)
2536     : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
2537   CHECK(shape_->IsTuple());
2538   CHECK(!ShapeUtil::IsNestedTuple(*shape_));
2539   CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
2540   root_piece_ = Piece();
2541   root_piece_.set_subshape(shape_.get());
2542   BuildPieceSubtree(*shape_, &root_piece_);
2543 
2544   for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
2545     const auto& src_shape = shape_->tuple_shapes(i);
2546     CHECK(src_shape.IsArray());
2547     root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
2548   }
2549 }
2550 
2551 }  // namespace xla
2552