/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"

namespace xla {
namespace {

using absl::StrCat;

constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Literals can be used as DMA targets, which can require alignment. We
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
// alignment.
constexpr int kMinimumAlignment = 64;

// Converts between little and big endian.
//
// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
void ConvertEndianShort(std::string* bytes) {
  CHECK_EQ(bytes->size() % 2, 0);
  for (int64_t i = 0, end = bytes->size(); i < end; i += 2) {
    std::swap((*bytes)[i], (*bytes)[i + 1]);
  }
}
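// For illustration (values hypothetical), swapping a buffer holding two
// 16-bit elements:
//   std::string bytes("\x12\x34\x56\x78", 4);
//   ConvertEndianShort(&bytes);  // bytes is now "\x34\x12\x78\x56"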

void ConvertEndianShort(char* bytes, int64_t size) {
  CHECK_EQ(size % 2, 0);
  for (int64_t i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}

std::string CompactOneline(const std::string& input) {
  std::string result;
  std::vector<std::string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
  bool first = true;
  // Concatenate elements in "v" with spaces separating them, but ignoring
  // empty entries.
  for (const auto& s : v) {
    if (s.empty()) {
      continue;
    }
    absl::StrAppend(&result, (first ? "" : " "), s);
    first = false;
  }
  return result;
}
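// e.g., CompactOneline("{ 1,\n  2,\n  3 }") returns "{ 1, 2, 3 }": newlines
// and runs of spaces collapse to single separating spaces.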

// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  return val;
}
uint16_t GetRawValue(Eigen::half val) {
  return Eigen::numext::bit_cast<uint16_t>(val);
}

bool LiteralProtoHasValues(const LiteralProto& proto) {
  return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
         proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
         proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
         proto.c64s_size() || proto.c128s_size() ||
         proto.tuple_literals_size() || !proto.f16s().empty() ||
         !proto.bf16s().empty() || !proto.u16s().empty() ||
         !proto.s16s().empty();
}

// Lazy getter for the interned scalar shape in static storage. We reuse this
// shape pointer when constructing scalar Literals, which can happen a lot
// when we are evaluating reduce-like ops in HloEvaluator, and copying the
// shape over and over again significantly slows down the evaluator.
template <PrimitiveType kType>
const Shape& ScalarShapeImpl() {
  static_assert(primitive_util::IsArrayType(kType),
                "Not a valid type for a scalar.");
  static const Shape* shape = [] {
    auto shape = new Shape(kType, {}, {}, {});
    shape->mutable_layout();
    return shape;
  }();
  return *shape;
}

const Shape& ScalarShape(PrimitiveType type) {
  switch (type) {
    case U8:
      return ScalarShapeImpl<U8>();
    case U16:
      return ScalarShapeImpl<U16>();
    case U32:
      return ScalarShapeImpl<U32>();
    case U64:
      return ScalarShapeImpl<U64>();
    case S8:
      return ScalarShapeImpl<S8>();
    case S16:
      return ScalarShapeImpl<S16>();
    case S32:
      return ScalarShapeImpl<S32>();
    case S64:
      return ScalarShapeImpl<S64>();
    case F16:
      return ScalarShapeImpl<F16>();
    case BF16:
      return ScalarShapeImpl<BF16>();
    case F32:
      return ScalarShapeImpl<F32>();
    case F64:
      return ScalarShapeImpl<F64>();
    case C64:
      return ScalarShapeImpl<C64>();
    case C128:
      return ScalarShapeImpl<C128>();
    case PRED:
      return ScalarShapeImpl<PRED>();
    case TUPLE:
      LOG(FATAL) << "Tuple element type cannot be a scalar type.";
    case OPAQUE_TYPE:
      LOG(FATAL) << "Opaque element type cannot be a scalar type.";
    case TOKEN:
      LOG(FATAL) << "Token element type cannot be a scalar type.";
    case PRIMITIVE_TYPE_INVALID:
      LOG(FATAL) << "Invalid primitive type.";
    default:
      LOG(FATAL) << "Unhandled primitive type " << type;
  }
}

const Shape& NilShape() {
  static const Shape* shape = new Shape(TUPLE, {}, {}, {});
  return *shape;
}

// Returns the interned shape pointer in static storage if it's a scalar shape
// or nil shape.
const Shape* TryInternShape(const Shape& shape) {
  if (shape.IsTuple() && shape.tuple_shapes_size() == 0) {
    return &NilShape();
  }
  if (shape.IsArray() && shape.dimensions_size() == 0 && shape.is_static() &&
      shape.layout().tiles_size() == 0 && shape.layout().memory_space() == 0) {
    return &ScalarShape(shape.element_type());
  }
  return nullptr;
}
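// e.g., a static rank-0 F32 shape with an untiled, default-memory-space
// layout interns to the shared scalar shape; any other shape returns nullptr
// and the caller keeps its own owned copy.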

}  // namespace

LiteralBase::~LiteralBase() {}

std::ostream& operator<<(std::ostream& out, const Literal& literal) {
  out << literal.ToString();
  return out;
}

MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64_t> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}
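// e.g., for copy_size {4, 16} between a {1,0}-major source and a {0,1}-major
// destination, the source's minor dimension covers 16 elements versus the
// destination's 4, so the tight loop runs over dimension 1 and dest_stride
// bridges the destination's non-contiguous layout.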

Shape* MutableLiteralBase::mutable_shape_do_not_use() {
  const Shape* const_shape = shape_.get();
  Shape* shape = shape_.get_mutable(/*ensure_owned=*/true);
  if (shape != const_shape) {
    std::function<void(const Shape&, Piece*)> set_piece_shapes =
        [&set_piece_shapes](const Shape& shape, Piece* piece) {
          piece->set_subshape(&shape);
          if (shape.IsTuple()) {
            for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
              const Shape& subshape = shape.tuple_shapes(i);
              set_piece_shapes(subshape, &piece->child(i));
            }
          }
        };
    set_piece_shapes(*shape, &mutable_root_piece());
  }
  return shape;
}

Literal::Literal() : Literal(NilShape()) {}

Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}

void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,
                       ArrayValueState leaf_array_value_state) {
  if (shape.IsTuple()) {
    for (const Shape& subshape : shape.tuple_shapes()) {
      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays, leaf_array_value_state);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    piece->set_array_value_state(leaf_array_value_state);
    if (leaf_array_value_state == LiteralBase::ArrayValueState::kKnown &&
        allocate_arrays) {
      piece->AllocateBuffers();
    }
  } else {
    // If the shape is neither an array nor a tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}

Literal::Literal(const Shape& shape, bool allocate_arrays,
                 ArrayValueState leaf_array_value_state)
    : MutableLiteralBase() {
  if (const Shape* interned_shape_ptr = TryInternShape(shape)) {
    shape_ = interned_shape_ptr;
  } else {
    shape_ = std::make_unique<Shape>(shape);
  }
  CHECK(leaf_array_value_state != ArrayValueState::kKnown ||
        LayoutUtil::HasLayout(*shape_));
  root_piece_.set_subshape(shape_.get());
  CHECK(&root_piece_.subshape() == shape_.get());

  SetPiece(*shape_, &root_piece_, allocate_arrays, leaf_array_value_state);
}

Literal::~Literal() { DeallocateBuffers(); }

void Literal::DeallocateBuffers() {
  root_piece_.ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        piece->DeallocateBuffers();
      });
}

Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}

Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_.subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  DCHECK(&root_piece_.subshape() == shape_.get());

  return *this;
}

Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_.ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}

Literal LiteralBase::CreateFromShapeWithUnknownLeafArrays(const Shape& shape) {
  Literal literal(shape, /*allocate_arrays=*/false, ArrayValueState::kUnknown);
  return literal;
}

Literal LiteralBase::CreateFromShapeWithUndeterminedLeafArrays(
    const Shape& shape) {
  Literal literal(shape, /*allocate_arrays=*/false,
                  ArrayValueState::kUndetermined);
  return literal;
}

int32_t LiteralBase::GetDynamicSize(int64_t dim_index) const {
  return GetDynamicSize(dim_index, {});
}

int32_t LiteralBase::GetDynamicSize(int64_t dim_index,
                                    const ShapeIndex& shape_index) const {
  return piece(shape_index).GetDynamicSize(dim_index);
}

std::optional<int64_t> LiteralBase::GetFirstInteger() const {
  switch (shape().element_type()) {
    case U8:
      return GetFirstElement<uint8_t>();
    case U16:
      return GetFirstElement<uint16_t>();
    case U32:
      return GetFirstElement<uint32_t>();
    case U64: {
      int64_t v = GetFirstElement<uint64_t>();
      if (v < 0) {
        return std::nullopt;
      }
      return v;
    }
    case S8:
      return GetFirstElement<int8_t>();
    case S16:
      return GetFirstElement<int16_t>();
    case S32:
      return GetFirstElement<int32_t>();
    case S64:
      return GetFirstElement<int64_t>();
    default:
      return std::nullopt;
  }
}
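// e.g., GetFirstInteger() on a U64 literal whose first element exceeds
// std::numeric_limits<int64_t>::max() comes out negative in the conversion
// above and therefore yields std::nullopt instead of a wrapped value.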

template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64_t> src_base,
    absl::Span<const int64_t> dest_base, absl::Span<const int64_t> copy_size) {
  const int64_t src_base_size = src_base.size();
  const int64_t dest_base_size = dest_base.size();
  TF_RET_CHECK(src_literal.shape().rank() == src_base_size);
  TF_RET_CHECK(shape().rank() == dest_base_size);

  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64_t> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If either of the two shapes is a scalar, we can just call StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform the copy only if neither src nor dest has a zero-element
    // dimension; otherwise the copy is a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64_t> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64_t>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64_t>());

      int64_t src_index = linear_index(src_literal.shape(), src_indexes);
      int64_t dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return OkStatus();
}

Status MutableLiteralBase::CopyElementFrom(
    const LiteralSlice& src_literal, absl::Span<const int64_t> src_index,
    absl::Span<const int64_t> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64_t src_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(src_literal.shape(),
                                                    src_index);
  const int64_t dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64_t primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return OkStatus();
}

/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto, bool prohibit_empty_literal) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
    return InvalidArgument(
        "Literal shape cannot include OPAQUE_TYPE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_.ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        const LiteralProto* proto_element = &proto;
        for (int64_t i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return OkStatus();
        }
        if (piece->subshape().element_type() == TOKEN) {
          return OkStatus();
        }

        CHECK(piece->subshape().IsArray());

        // When prohibit_empty_literal is false (allowing literals with no
        // values), only copy from the proto if the literal proto has values.
        // This mode is used for a learned cost model.
        if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
          TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
        }

        return OkStatus();
      }));

  return std::move(literal);
}

Literal Literal::SubLiteral(ShapeIndexView shape_index) {
  if (!shape_index.empty()) {
    auto decomposed = this->DecomposeTuple();
    return decomposed.at(shape_index.front())
        .SubLiteral(shape_index.subspan(1));
  } else {
    return std::move(*this);
  }
}

std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  const auto tuple_element_count = ShapeUtil::TupleElementCount(shape());
  elements.reserve(tuple_element_count);
  for (int i = 0; i < tuple_element_count; ++i) {
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_.ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          if (dest_piece->subshape().IsTuple()) {
            return;
          }
          ShapeIndex src_index = {i};
          for (int64_t j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->MoveDataFrom(src_piece);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
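// Sketch of a typical use (names hypothetical): for a literal `tuple` of
// shape (f32[2], s32[]),
//   std::vector<Literal> parts = tuple.DecomposeTuple();
// moves the element buffers into parts[0] and parts[1] without copying and
// leaves `tuple` nil-shaped.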

namespace {

// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64_t> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
}  // namespace

int32_t LiteralBase::Piece::GetDynamicSize(int64_t dim_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  if (!subshape_->is_dynamic_dimension(dim_index)) {
    // This is a static dimension, return size.
    return subshape_->dimensions(dim_index);
  }
  return dynamic_size_buffer()[dim_index];
}

void LiteralBase::Piece::SetDynamicSize(int64_t dim_index, int32_t size) {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  CHECK(subshape_->is_dynamic_dimension(dim_index));
  dynamic_size_buffer()[dim_index] = size;
}

void LiteralBase::Piece::AllocateBuffers() {
  const int64_t bytes = total_bytes();
  if (bytes > kMaxInlinedBytes) {
    CHECK_EQ(buffer(), nullptr);
    rep_.emplace<ArrayRep>();
    set_buffer(static_cast<char*>(
        tensorflow::port::AlignedMalloc(bytes, kMinimumAlignment)));
  } else {
    rep_.emplace<InlinedRep>();
  }
}

void LiteralBase::Piece::DeallocateBuffers() {
  if (auto* array_rep = GetArrayRep()) {
    tensorflow::port::AlignedFree(array_rep->data);
    rep_.emplace<Uninitialized>();
  }
}

Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
                                    bool only_dynamic_bound) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (src.array_value_state_ == ArrayValueState::kUnknown ||
      src.array_value_state_ == ArrayValueState::kUndetermined) {
    if (array_value_state_ == ArrayValueState::kKnown) {
      DeallocateBuffers();
    }
    array_value_state_ = src.array_value_state_;
    return OkStatus();
  } else {
    CHECK(src.array_value_state_ == ArrayValueState::kKnown);
    if (array_value_state_ == ArrayValueState::kUndetermined ||
        array_value_state_ == ArrayValueState::kUnknown) {
      AllocateBuffers();
    }
    array_value_state_ = src.array_value_state_;
  }

  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    std::vector<int64_t> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                      \
  case (XLA_T):                                                             \
    if (only_dynamic_bound) {                                               \
      CopyElementsWithDynamicBound<NATIVE_T>(src);                          \
    } else {                                                                \
      CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                    subshape(), src.subshape());            \
    }                                                                       \
    break;
      COPY_ELEMENTS(U8, uint8_t);
      COPY_ELEMENTS(U16, uint16_t);
      COPY_ELEMENTS(U32, uint32_t);
      COPY_ELEMENTS(U64, uint64_t);
      COPY_ELEMENTS(S8, int8_t);
      COPY_ELEMENTS(S16, int16_t);
      COPY_ELEMENTS(S32, int32_t);
      COPY_ELEMENTS(S64, int64_t);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
  if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
    memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
           src.dynamic_size_buffer_bytes());
  }
  return OkStatus();
}

void MutableLiteralBase::SetDynamicSize(int64_t dim_index, int32_t size) {
  return SetDynamicSize(dim_index, {}, size);
}

void MutableLiteralBase::SetDynamicSize(int64_t dim_index,
                                        const ShapeIndex& shape_index,
                                        int32_t size) {
  Shape* subshape =
      ShapeUtil::GetMutableSubshape(mutable_shape_do_not_use(), shape_index);
  CHECK_GE(subshape->dimensions(dim_index), size);
  if (subshape->dimensions(dim_index) == size) {
    subshape->set_dynamic_dimension(dim_index, false);
    return;
  }
  subshape->set_dynamic_dimension(dim_index, true);
  CHECK_EQ(&piece(shape_index).subshape(), subshape);

  piece(shape_index).SetDynamicSize(dim_index, size);
}

Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index,
                                    bool only_dynamic_bound) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (only_dynamic_bound) {
    auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
    auto compact_shape =
        dest_subshape.is_static() ? dest_subshape : src_subshape;
    CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
        << compact_shape.ToString() << " vs " << bound_shape.ToString();
  } else {
    if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
      return InvalidArgument(
          "Destination subshape incompatible with source subshape: %s vs %s",
          ShapeUtil::HumanString(dest_subshape),
          ShapeUtil::HumanString(src_subshape));
    }
  }
  return mutable_root_piece().ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return OkStatus();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return OkStatus();
        }
        // Construct the index of the corresponding piece in the source literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64_t i = dest_shape_index.size(), end = index.size(); i < end;
             ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(
            piece->CopyFrom(src_literal.piece(src_piece_index),
                            /*only_dynamic_bound=*/only_dynamic_bound));
        return OkStatus();
      });
}

Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_.ForEachMutableSubpiece(
      [&](const ShapeIndex& src_index, Piece* src_piece) {
        if (!src_piece->subshape().IsArray()) {
          return;
        }

        ShapeIndex dest_index = dest_shape_index;
        for (int64_t i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        dest_piece.DeallocateBuffers();
        dest_piece.MoveDataFrom(*src_piece);
      });

  src_literal.shape_ = MaybeOwningShapePtr(&NilShape());
  src_literal.root_piece_ = Piece();
  src_literal.root_piece_.set_subshape(src_literal.shape_.get());

  return OkStatus();
}
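// Sketch of a typical call (names hypothetical): given a tuple-shaped `dest`
// whose element {0} has exactly the shape of `src`,
//   TF_RETURN_IF_ERROR(dest.MoveFrom(std::move(src), /*dest_shape_index=*/{0}));
// adopts src's buffers without copying and resets src to the nil shape.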

Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64_t> src_base,
                                         absl::Span<const int64_t> dest_base,
                                         absl::Span<const int64_t> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case U16:
      return CopySliceFromInternal<uint16_t>(src_literal, src_base, dest_base,
                                             copy_size);
    case U32:
      return CopySliceFromInternal<uint32_t>(src_literal, src_base, dest_base,
                                             copy_size);
    case U64:
      return CopySliceFromInternal<uint64_t>(src_literal, src_base, dest_base,
                                             copy_size);
    case S8:
      return CopySliceFromInternal<int8_t>(src_literal, src_base, dest_base,
                                           copy_size);
    case S16:
      return CopySliceFromInternal<int16_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case S32:
      return CopySliceFromInternal<int32_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case S64:
      return CopySliceFromInternal<int64_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
                                               copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}

void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(element_count(), values.bits());
  CHECK_EQ(shape().element_type(), PRED);
  for (int64_t i = 0; i < static_cast<int64_t>(values.bits()); ++i) {
    Set({i}, values.get(i));
  }
}

Literal LiteralBase::Relayout(const Layout& new_layout,
                              const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}
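// e.g., relaying out a row-major f32[2,3]{1,0} literal with a column-major
// layout {0,1} preserves every element value but reorders the underlying
// buffer, since Piece::CopyFrom falls back to an element-by-element copy
// when the layouts differ.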

Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}

Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
  CHECK(bounded_shape.is_dynamic());
  Literal result(bounded_shape);
  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape.rank(); ++i) {
          result.SetDynamicSize(i, subshape.dimensions(i));
        }
      });
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));

  return result;
}

Literal LiteralBase::ToStatic() const {
  // Create a new shape with dynamic dimensions cleared and each dimension
  // resized to its current dynamic size.
  Shape new_shape = shape();
  ShapeUtil::ForEachMutableSubshape(
      &new_shape, [this](Shape* subshape, const ShapeIndex& index) {
        if (!subshape->IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape->rank(); ++i) {
          subshape->set_dynamic_dimension(i, false);
          subshape->set_dimensions(i, GetDynamicSize(i, index));
        }
      });
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
  return result;
}
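// e.g., a literal of dynamic shape f32[<=4,2] whose first dimension currently
// has dynamic size 3 becomes a static f32[3,2] literal under ToStatic();
// ToBoundedDynamic() goes the other direction, re-attaching a dynamic bound
// to a static literal.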

StatusOr<Literal> LiteralBase::Broadcast(
    const Shape& result_shape, absl::Span<const int64_t> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Broadcast only supports arrays.");
  }

  for (int64_t i = 0, end = dimensions.size(); i < end; i++) {
    TF_RET_CHECK(shape().dimensions(i) ==
                 result_shape.dimensions(dimensions[i]));
  }

  TF_RET_CHECK(result_shape.element_type() == shape().element_type());
  Literal result(result_shape);
  // scratch_source_index is temporary storage space for the computed index into
  // the input literal.  We put it here to avoid allocating an std::vector in
  // every iteration of ShapeUtil::ForEachIndex.
  std::vector<int64_t> scratch_source_index(shape().dimensions_size());

  char* dest_data = static_cast<char*>(result.untyped_data());
  const char* source_data = static_cast<const char*>(untyped_data());
  const int64_t primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
  for (int64_t i = 0; i < dimensions.size(); ++i) {
    int64_t dynamic_size = GetDynamicSize(i);
    result.SetDynamicSize(dimensions[i], dynamic_size);
  }

  ShapeUtil::ForEachIndex(
      result_shape, [&](absl::Span<const int64_t> output_index) {
        for (int64_t i = 0, end = dimensions.size(); i < end; ++i) {
          scratch_source_index[i] = output_index[dimensions[i]];
        }
        int64_t dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            result_shape, output_index);
        int64_t source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            shape(), scratch_source_index);
        memcpy(dest_data + primitive_size * dest_index,
               source_data + primitive_size * source_index, primitive_size);
        return true;
      });

  return std::move(result);
}
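// e.g., broadcasting an f32[2] literal to result_shape f32[3,2] with
// dimensions={1} maps source dimension 0 onto output dimension 1, so each
// of the 3 rows of the result repeats the 2 source values.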

StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64_t> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  if (shape().is_dynamic()) {
    return Unimplemented("Dynamic reshape is not implemented.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64_t elements_before = ShapeUtil::ElementsIn(shape());
  int64_t elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
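// e.g., Reshape({3, 2}) on an f32[2,3] literal first normalizes the data to
// row-major order if needed, then reinterprets the same flat sequence of 6
// values under the new dimension bounds.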

Literal LiteralBase::Transpose(absl::Span<const int64_t> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the
  // one of size 8) as the most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  std::vector<int64_t> inverse_permutation = InversePermutation(permutation);
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  for (int64_t i = 0; i < shape().rank(); i++) {
    new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
  }
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}

template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64_t> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  CHECK(result_literal
            .Populate<NativeT>([&](absl::Span<const int64_t> indices) {
              for (int64_t i = 0; i < result_shape.rank(); ++i) {
                new_indices[i] = indices[i] + start_indices[i];
              }
              return Get<NativeT>(new_indices);
            })
            .ok());
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    if (shape().is_dynamic_dimension(dnum)) {
      int64_t dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
      CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
      dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
      result_literal.SetDynamicSize(dnum, dynamic_size);
    }
  }
  return result_literal;
}

Literal LiteralBase::Slice(absl::Span<const int64_t> start_indices,
                           absl::Span<const int64_t> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64_t dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8_t>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16_t>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32_t>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64_t>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8_t>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16_t>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32_t>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64_t>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}
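// e.g., Slice({1, 0}, {3, 2}) on an f32[4,4] literal yields an f32[2,2]
// literal whose element (i, j) is the source's element (i + 1, j).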
1105 
Clone() const1106 Literal LiteralBase::Clone() const {
1107   Literal result(shape());
1108   TF_CHECK_OK(result.CopyFrom(*this));
1109   return result;
1110 }
1111 
CloneToUnique() const1112 std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
1113   auto result = std::make_unique<Literal>(shape());
1114   TF_CHECK_OK(result->CopyFrom(*this));
1115   return result;
1116 }
1117 
IsDetermined(const ShapeIndex & shape_index) const1118 bool LiteralBase::IsDetermined(const ShapeIndex& shape_index) const {
1119   return piece(shape_index).IsDetermined();
1120 }
1121 
IsKnown(const ShapeIndex & shape_index) const1122 bool LiteralBase::IsKnown(const ShapeIndex& shape_index) const {
1123   return piece(shape_index).IsKnown();
1124 }
1125 
GetAsString(absl::Span<const int64_t> multi_index,const ShapeIndex & shape_index) const1126 std::string LiteralBase::GetAsString(absl::Span<const int64_t> multi_index,
1127                                      const ShapeIndex& shape_index) const {
1128   const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
1129   CHECK(LayoutUtil::IsDenseArray(subshape));
1130   switch (subshape.element_type()) {
1131     case PRED:
1132       return Get<bool>(multi_index, shape_index) ? "true" : "false";
1133     case S8:
1134       return StrCat(Get<int8_t>(multi_index, shape_index));
1135     case S16:
1136       return StrCat(Get<int16_t>(multi_index, shape_index));
1137     case S32:
1138       return StrCat(Get<int32_t>(multi_index, shape_index));
1139     case S64:
1140       return StrCat(Get<int64_t>(multi_index, shape_index));
1141     case U8:
1142       return StrCat(Get<uint8_t>(multi_index, shape_index));
1143     case U16:
1144       return StrCat(Get<uint16_t>(multi_index, shape_index));
1145     case U32:
1146       return StrCat(Get<uint32_t>(multi_index, shape_index));
1147     case U64:
1148       return StrCat(Get<uint64_t>(multi_index, shape_index));
1149     case F16:
1150       return RoundTripFpToString(Get<half>(multi_index, shape_index));
1151     case F32:
1152       return RoundTripFpToString(Get<float>(multi_index, shape_index));
1153     case BF16:
1154       return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
1155     case F64:
1156       return RoundTripFpToString(Get<double>(multi_index, shape_index));
1157     case C64: {
1158       complex64 c = Get<complex64>(multi_index, shape_index);
1159       return StrCat("(", RoundTripFpToString(c.real()), ", ",
1160                     RoundTripFpToString(c.imag()), ")");
1161     }
1162     case C128: {
1163       complex128 c = Get<complex128>(multi_index, shape_index);
1164       return StrCat("(", RoundTripFpToString(c.real()), ", ",
1165                     RoundTripFpToString(c.imag()), ")");
1166     }
1167     default:
1168       LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
1169   }
1170 }
1171 
GetIntegralAsS64(absl::Span<const int64_t> multi_index) const1172 std::optional<int64_t> LiteralBase::GetIntegralAsS64(
1173     absl::Span<const int64_t> multi_index) const {
1174   CHECK(LayoutUtil::IsDenseArray(shape()));
1175   switch (shape().element_type()) {
1176     case PRED:
1177       return Get<bool>(multi_index);
1178     case S8:
1179       return Get<int8_t>(multi_index);
1180     case U8:
1181       return Get<uint8_t>(multi_index);
1182     case S16:
1183       return Get<int16_t>(multi_index);
1184     case U16:
1185       return Get<uint16_t>(multi_index);
1186     case S32:
1187       return Get<int32_t>(multi_index);
1188     case U32:
1189       return Get<uint32_t>(multi_index);
1190     case S64:
1191       return Get<int64_t>(multi_index);
1192     case U64:
1193       return Get<uint64_t>(multi_index);
1194     default:
1195       return std::nullopt;
1196   }
1197 }
1198 
GetAsDouble(absl::Span<const int64_t> multi_index) const1199 std::optional<double> LiteralBase::GetAsDouble(
1200     absl::Span<const int64_t> multi_index) const {
1201   CHECK(LayoutUtil::IsDenseArray(shape()));
1202   switch (shape().element_type()) {
1203     case F16:
1204       return static_cast<double>(Get<half>(multi_index));
1205     case F32:
1206       return static_cast<double>(Get<float>(multi_index));
1207     case F64:
1208       return Get<double>(multi_index);
1209     case BF16:
1210       return static_cast<double>(Get<bfloat16>(multi_index));
1211     default:
1212       return std::nullopt;
1213   }
1214 }
1215 
GetAsComplex128(absl::Span<const int64_t> multi_index) const1216 std::optional<complex128> LiteralBase::GetAsComplex128(
1217     absl::Span<const int64_t> multi_index) const {
1218   switch (shape().element_type()) {
1219     case BF16:
1220       return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
1221     case F16:
1222       return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
1223     case F32:
1224       return {{Get<float>(multi_index), 0}};
1225     case F64:
1226       return {{Get<double>(multi_index), 0}};
1227     case C64:
1228       return {Get<complex64>(multi_index)};
1229     case C128:
1230       return {Get<complex128>(multi_index)};
1231     case S8:
1232       return {Get<int8_t>(multi_index)};
1233     default:
1234       return std::nullopt;
1235   }
1236 }
1237 
SetIntegralAsS64(absl::Span<const int64_t> multi_index,int64_t value)1238 Status MutableLiteralBase::SetIntegralAsS64(
1239     absl::Span<const int64_t> multi_index, int64_t value) {
1240   CHECK(LayoutUtil::IsDenseArray(shape()));
1241   switch (shape().element_type()) {
1242     case PRED:
1243       Set<bool>(multi_index, value);
1244       break;
1245     case U8:
1246       Set<uint8_t>(multi_index, value);
1247       break;
1248     case S32:
1249       Set<int32_t>(multi_index, value);
1250       break;
1251     case S64:
1252       Set<int64_t>(multi_index, value);
1253       break;
1254     case U32:
1255       Set<uint32_t>(multi_index, value);
1256       break;
1257     case U64:
1258       Set<uint64_t>(multi_index, value);
1259       break;
1260     default:
1261       return FailedPrecondition("Array element type is not integral: %s",
1262                                 PrimitiveType_Name(shape().element_type()));
1263   }
1264   return OkStatus();
1265 }
1266 
SetFromDouble(absl::Span<const int64_t> multi_index,double value)1267 Status MutableLiteralBase::SetFromDouble(absl::Span<const int64_t> multi_index,
1268                                          double value) {
1269   CHECK(LayoutUtil::IsDenseArray(shape()));
1270   switch (shape().element_type()) {
1271     case F16:
1272       Set<half>(multi_index, Eigen::half(value));
1273       break;
1274     case F32:
1275       Set<float>(multi_index, value);
1276       break;
1277     case F64:
1278       Set<double>(multi_index, value);
1279       break;
1280     case BF16:
1281       Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
1282       break;
1283     default:
1284       return FailedPrecondition("Array element type is not floating: %s",
1285                                 PrimitiveType_Name(shape().element_type()));
1286   }
1287   return OkStatus();
1288 }
1289 
namespace {

std::string ShapeToString(bool print_layout, const Shape& shape) {
  return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
                      : ShapeUtil::HumanString(shape);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<std::string>* pieces);

void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<std::string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<std::string> tuple_pieces;
  const auto tuple_element_count = ShapeUtil::TupleElementCount(subshape);
  tuple_pieces.reserve(tuple_element_count);
  for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<std::string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}

void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout,
                              std::vector<std::string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64_t rank = subshape.rank();

  std::function<void(absl::Span<const int64_t> dimensions,
                     std::vector<int64_t>*)>
      to_string_recursive = [&](absl::Span<const int64_t> dimensions,
                                std::vector<int64_t>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        auto brace_to_string = [&](std::string brace) -> std::string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat("  ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            const int64_t accum_indices_size = accum_indices->size();
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices_size < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          std::string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            accum_indices->push_back(i);
            to_string_recursive(dimensions.subspan(1), accum_indices);
            accum_indices->pop_back();
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    if (subshape.is_dynamic()) {
      pieces->push_back("(");
      for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
        pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
        if (i < subshape.dimensions_size() - 1) {
          pieces->push_back(",");
        }
      }
      pieces->push_back(")");
    }
    pieces->push_back(" ");
  }
  std::vector<int64_t> indices = {};
  std::vector<int64_t> dimensions;
  dimensions.reserve(subshape.rank());
  for (int64_t i = 0; i < subshape.rank(); ++i) {
    dimensions.push_back(literal.GetDynamicSize(i, shape_index));
  }
  to_string_recursive(dimensions, &indices);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<std::string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    if (literal.IsKnown(shape_index)) {
      DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                               pieces);
    } else {
      pieces->push_back(ShapeToString(print_layout, subshape));
      pieces->push_back(" ");
      if (literal.IsDetermined(shape_index)) {
        pieces->push_back("unknown");
      } else {
        pieces->push_back("undetermined");
      }
    }
  }
}

}  // namespace

std::string LiteralBase::ToString() const {
  std::vector<std::string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

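// Example (illustrative): for a 2x2 F32 literal the output of ToString()
// looks roughly like
//
//   f32[2,2] {
//     { 1, 2 },
//     { 3, 4 }
//   }
//
// while the *Oneline variants below collapse the same text onto one line.
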
std::string LiteralBase::ToStringOneline() const {
  return CompactOneline(ToString());
}

std::string LiteralBase::ToStringWithoutShape() const {
  std::vector<std::string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/false,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

std::string LiteralBase::ToStringWithoutShapeOneline() const {
  return CompactOneline(ToStringWithoutShape());
}

std::string LiteralBase::ToStringWithLayout() const {
  std::vector<std::string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/true, &pieces);
  return absl::StrJoin(pieces, "");
}

std::string LiteralBase::ToStringWithLayoutOneline() const {
  return CompactOneline(ToStringWithLayout());
}

void LiteralBase::EachCellAsString(
    const std::function<void(absl::Span<const int64_t> indices,
                             const std::string& value)>& per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  std::vector<int64_t> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
      shape(), /*linear_index=*/0);
  do {
    per_cell(indices, GetAsString(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}

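// Example (illustrative sketch): EachCellAsString visits every cell (in the
// order produced by IndexUtil::BumpIndices) and hands back stringified
// values.
//
//   literal.EachCellAsString(
//       [](absl::Span<const int64_t> indices, const std::string& value) {
//         LOG(INFO) << absl::StrJoin(indices, ",") << " -> " << value;
//       });
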
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
                                               const ConverterType& converter) {
  CHECK(src_literal.shape().IsArray());
  Literal result_literal(ShapeUtil::ChangeElementType(
      src_literal.shape(),
      primitive_util::NativeToPrimitiveType<NativeDestT>()));
  auto src_data = src_literal.data<NativeSrcT>();
  auto dest_data = result_literal.template data<NativeDestT>();
  int64_t num_elements = src_literal.element_count();

  for (int64_t i = 0; i < num_elements; ++i) {
    dest_data[i] = converter(src_data[i]);
  }
  return result_literal;
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_same<NativeSrcT, Eigen::half>::value &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_floating_point<NativeSrcT>::value &&
                            std::is_integral<NativeDestT>::value,
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    // C++ [conv.bool]p1:
    //   A prvalue of arithmetic [...] type can be converted to a prvalue of
    //   type bool. A zero value [...] is converted to false; any other value is
    //   converted to true.
    // C++ [conv.fpint]p1:
    //   [...] The behavior is undefined if the truncated value cannot be
    //   represented in the destination type.
    //
    // Using static_cast to convert a float to an integral type other than bool
    // may be undefined if the value's magnitude is too large or it is a NaN.
    // Let's choose saturating arithmetic as it captures the spirit of infinity
    // and arbitrarily map NaN to zero.
    if (!std::is_same<NativeDestT, bool>::value) {
      if (src != src) {
        return NativeDestT{0};
      }
      if (src >= std::numeric_limits<NativeDestT>::max()) {
        return std::numeric_limits<NativeDestT>::max();
      }
      if (src <= std::numeric_limits<NativeDestT>::lowest()) {
        return std::numeric_limits<NativeDestT>::lowest();
      }
    }
    return static_cast<NativeDestT>(src);
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

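// Example (illustrative): semantics of the saturating float-to-int converter
// above, for an F32 source converted to S32.
//
//   F32 value        resulting S32 value
//   NaN           -> 0
//   1e30f         -> std::numeric_limits<int32_t>::max()
//   -1e30f        -> std::numeric_limits<int32_t>::lowest()
//   2.75f         -> 2   (truncation toward zero via static_cast)
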
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<!(std::is_floating_point<NativeSrcT>::value &&
                          std::is_integral<NativeDestT>::value) &&
                            !(std::is_same<NativeSrcT, Eigen::half>::value &&
                              (std::is_same<NativeDestT, complex64>::value ||
                               std::is_same<NativeDestT, complex128>::value)),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
                         !std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return absl::bit_cast<NativeDestT>(GetRawValue(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short first.
  auto converter = [](NativeSrcT src) {
    return Eigen::numext::bit_cast<Eigen::half>(
        absl::bit_cast<uint16_t>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}

// This template specialization is here to make the compiler happy. bit_cast has
// a static check that the types are the same size. This specialization should
// never be used because the source and destination types are checked for
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}

template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}

StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
      // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}

}  // namespace

StatusOr<Literal> LiteralBase::Convert(
    PrimitiveType primitive_dest_type) const {
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}

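// Example (illustrative sketch): element-wise conversion to a new element
// type. The LiteralUtil construction is an assumption.
//
//   Literal s32 = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
//   StatusOr<Literal> f32 = s32.Convert(F32);  // f32[3] {1, 2, 3}
//   // Converting to an unsupported type returns an Unimplemented error.
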
StatusOr<Literal> LiteralBase::BitcastConvert(const Shape& dest_shape) const {
  if (ShapeUtil::ByteSizeOf(dest_shape) != ShapeUtil::ByteSizeOf(shape())) {
    return InvalidArgument(
        "Can not bitcast-convert from shape %s to a shape of different size %s",
        shape().ToString(), dest_shape.ToString());
  }
  if (dest_shape.IsTuple() || shape().IsTuple()) {
    return InvalidArgument(
        "bitcast-convert is not valid for tuple shapes %s->%s",
        shape().ToString(), dest_shape.ToString());
  }
  if (shape().is_dynamic() || dest_shape.is_dynamic()) {
    return InvalidArgument(
        "bitcast-convert is not valid for dynamic shape %s->%s",
        shape().ToString(), dest_shape.ToString());
  }

  Literal out(dest_shape);
  std::memcpy(out.root_piece_.buffer(), root_piece().buffer(),
              root_piece().size_bytes());
  return out;
}

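// Example (illustrative): BitcastConvert reinterprets the underlying bytes
// rather than the values, so the byte sizes of the two shapes must match.
//
//   Literal f32 = LiteralUtil::CreateR0<float>(1.0f);
//   StatusOr<Literal> u32 =
//       f32.BitcastConvert(ShapeUtil::MakeShape(U32, {}));
//   // On success, u32 holds 0x3F800000, the IEEE-754 bit pattern of 1.0f.
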
StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
  if (!dest_shape.IsTuple()) {
    return Convert(dest_shape.element_type());
  }
  std::vector<Literal> elements;
  const auto tuple_element_count = ShapeUtil::TupleElementCount(shape());
  elements.reserve(tuple_element_count);
  for (int i = 0; i < tuple_element_count; ++i) {
    auto element = LiteralSlice(*this, {i});
    TF_ASSIGN_OR_RETURN(
        auto new_element,
        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
    elements.push_back(std::move(new_element));
  }
  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}

/* static */ Literal MutableLiteralBase::MoveIntoTuple(
    absl::Span<Literal> elements) {
  std::vector<const Shape*> element_shapes;
  element_shapes.reserve(elements.size());
  for (const Literal& element : elements) {
    element_shapes.push_back(&element.shape());
  }
  Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes),
                  /*allocate_arrays=*/false);
  for (int i = 0, end = elements.size(); i < end; ++i) {
    TF_CHECK_OK(
        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
  }
  return literal;
}

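// Example (illustrative sketch): assembling a tuple literal without copying
// the element buffers. The elements are consumed (moved from).
//
//   std::vector<Literal> parts;
//   parts.push_back(LiteralUtil::CreateR0<float>(1.0f));
//   parts.push_back(LiteralUtil::CreateR1<int32_t>({1, 2}));
//   Literal tuple =
//       MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(parts));
//   // tuple has shape (f32[], s32[2]); `parts` now holds moved-from shells.
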
template <typename NativeT>
void LiteralBase::Piece::CopyElementsWithDynamicBound(
    const LiteralBase::Piece& src) {
  auto dest_shape = subshape();
  auto src_shape = src.subshape();

  // At least one of the shapes must be static to serve as the bound.
  CHECK(dest_shape.is_static() || src_shape.is_static());
  auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64_t> index(dest_shape.rank());
  do {
    bool out_of_bound = false;
    for (int64_t i = 0; i < index.size(); ++i) {
      // Do not copy elements beyond the dynamic bound.
      if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
        out_of_bound = true;
      }
    }
    if (out_of_bound) {
      continue;
    }
    data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
                                                                  index)] =
        src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
            src_shape, index)];
  } while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
}

template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64_t>* multi_index) const {
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64_t i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}

bool LiteralBase::Piece::EqualDynamicSize(
    const LiteralBase::Piece& other) const {
  DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
  if (subshape().is_static()) {
    return true;
  }

  for (int64_t i = 0; i < subshape().rank(); ++i) {
    if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
      return false;
    }
  }
  return true;
}

bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
  if (subshape().is_static() &&
      ShapeUtil::Equal(subshape(), other.subshape()) &&
      LayoutUtil::IsDenseArray(subshape())) {
    CHECK_EQ(size_bytes(), other.size_bytes());
    return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
  }

  std::vector<int64_t> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case S8:
      return EqualElementsInternal<int8_t>(other, &multi_index);
    case S16:
      return EqualElementsInternal<int16_t>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32_t>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64_t>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8_t>(other, &multi_index);
    case U16:
      return EqualElementsInternal<uint16_t>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32_t>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64_t>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    case C128:
      return EqualElementsInternal<complex128>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

bool LiteralBase::operator==(const LiteralBase& other) const {
  // Check the structure of tuple literals first; checks for dense arrays are
  // performed below.
  if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
    return false;
  }

  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        const Piece& other_piece = other.piece(index);
        const Shape& subshape = piece.subshape();
        const Shape& other_subshape = other_piece.subshape();
        if (subshape.element_type() != other_subshape.element_type()) {
          return false;
        }
        if (!piece.subshape().IsArray()) {
          return true;
        }
        if (subshape.rank() != other_subshape.rank()) {
          return false;
        }

        for (int64_t i = 0; i < subshape.rank(); ++i) {
          if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
            return false;
          }
        }

        if (!piece.EqualElements(other_piece)) {
          return false;
        }
        return true;
      });
}

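// Example (illustrative): equality is element-wise value equality, not
// buffer identity.
//
//   Literal a = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
//   Literal b = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
//   Literal c = LiteralUtil::CreateR1<int32_t>({1, 2, 4});
//   CHECK(a == b);
//   CHECK(!(a == c));
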
template <typename NativeT>
static bool EqualIncludingNan(NativeT a, NativeT b) {
  // msvc can't compile std::isnan(a) where `a` is uint8_t.  This is a bug
  // according to https://en.cppreference.com/w/cpp/numeric/math/isnan, but it's
  // easy to work around.
  return a == b || (std::isnan(static_cast<double>(a)) &&
                    std::isnan(static_cast<double>(b)));
}

template <typename T>
static bool EqualIncludingNan(std::complex<T> a, std::complex<T> b) {
  return EqualIncludingNan(a.real(), b.real()) &&
         EqualIncludingNan(a.imag(), b.imag());
}

template <typename NativeT>
static bool AllElementsEqualValue(absl::Span<const NativeT> data,
                                  NativeT value) {
  for (int64_t i = 0; i < data.size(); ++i) {
    if (!EqualIncludingNan(data[i], value)) {
      return false;
    }
  }
  return true;
}

bool Literal::Piece::IsAll(const Literal& scalar) const {
  CHECK(ShapeUtil::IsScalar(scalar.shape())) << scalar.shape().ToString();
  if (!subshape().IsArray()) {
    return false;
  }

  CHECK_EQ(subshape().element_type(), scalar.shape().element_type());
  switch (subshape().element_type()) {
    case U8:
      return AllElementsEqualValue(data<uint8_t>(),
                                   scalar.GetFirstElement<uint8_t>());
    case U16:
      return AllElementsEqualValue(data<uint16_t>(),
                                   scalar.GetFirstElement<uint16_t>());
    case U32:
      return AllElementsEqualValue(data<uint32_t>(),
                                   scalar.GetFirstElement<uint32_t>());
    case U64:
      return AllElementsEqualValue(data<uint64_t>(),
                                   scalar.GetFirstElement<uint64_t>());
    case S8:
      return AllElementsEqualValue(data<int8_t>(),
                                   scalar.GetFirstElement<int8_t>());
    case S16:
      return AllElementsEqualValue(data<int16_t>(),
                                   scalar.GetFirstElement<int16_t>());
    case S32:
      return AllElementsEqualValue(data<int32_t>(),
                                   scalar.GetFirstElement<int32_t>());
    case S64:
      return AllElementsEqualValue(data<int64_t>(),
                                   scalar.GetFirstElement<int64_t>());
    case PRED:
      return AllElementsEqualValue(data<bool>(),
                                   scalar.GetFirstElement<bool>());
    case F16:
      return AllElementsEqualValue(data<half>(),
                                   scalar.GetFirstElement<half>());
    case BF16:
      return AllElementsEqualValue(data<bfloat16>(),
                                   scalar.GetFirstElement<bfloat16>());
    case F32:
      return AllElementsEqualValue(data<float>(),
                                   scalar.GetFirstElement<float>());
    case F64:
      return AllElementsEqualValue(data<double>(),
                                   scalar.GetFirstElement<double>());
    case C64:
      return AllElementsEqualValue(data<complex64>(),
                                   scalar.GetFirstElement<complex64>());
    case C128:
      return AllElementsEqualValue(data<complex128>(),
                                   scalar.GetFirstElement<complex128>());
    default:
      return false;
  }
}

bool LiteralBase::IsAll(const Literal& scalar) const {
  return root_piece().IsAll(scalar);
}

bool LiteralBase::IsAll(int8_t value) const {
  if (!shape().IsArray()) {
    return false;
  }
  PrimitiveType ty = shape().element_type();
  if (primitive_util::IsFloatingPointType(ty)) {
    return IsAllFloat(value);
  }
  if (primitive_util::IsUnsignedIntegralType(ty) && value < 0) {
    return false;
  }
  Literal scalar(ShapeUtil::MakeScalarShape(ty));
  switch (ty) {
    case U8:
      scalar.Set<uint8_t>({}, value);
      break;
    case U16:
      scalar.Set<uint16_t>({}, value);
      break;
    case U32:
      scalar.Set<uint32_t>({}, value);
      break;
    case U64:
      scalar.Set<uint64_t>({}, value);
      break;
    case S8:
      scalar.Set<int8_t>({}, value);
      break;
    case S16:
      scalar.Set<int16_t>({}, value);
      break;
    case S32:
      scalar.Set<int32_t>({}, value);
      break;
    case S64:
      scalar.Set<int64_t>({}, value);
      break;
    case PRED:
      if (value == 0) {
        scalar.Set<bool>({}, false);
      } else if (value == 1) {
        scalar.Set<bool>({}, true);
      } else {
        return false;
      }
      break;
    default:
      return false;
  }
  return root_piece().IsAll(scalar);
}

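// Example (illustrative): IsAll builds a scalar of the literal's own element
// type and compares every element against it.
//
//   Literal zeros = LiteralUtil::CreateR2<uint32_t>({{0, 0}, {0, 0}});
//   CHECK(zeros.IsAll(0));
//   CHECK(!zeros.IsAll(-1));  // a negative value never matches unsigned data
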
bool LiteralBase::IsAllFloat(float value) const {
  if (!shape().IsArray()) {
    return false;
  }
  PrimitiveType ty = shape().element_type();
  Literal scalar(ShapeUtil::MakeScalarShape(ty));
  switch (ty) {
    case F16:
      scalar.Set<half>({}, static_cast<half>(value));
      break;
    case BF16:
      scalar.Set<bfloat16>({}, static_cast<bfloat16>(value));
      break;
    case F32:
      scalar.Set<float>({}, value);
      break;
    case F64:
      scalar.Set<double>({}, value);
      break;
    default:
      return false;
  }
  return root_piece().IsAll(scalar);
}

bool LiteralBase::IsAllComplex(complex64 value) const {
  if (!shape().IsArray()) {
    return false;
  }
  PrimitiveType ty = shape().element_type();
  Literal scalar(ShapeUtil::MakeScalarShape(ty));
  switch (ty) {
    case C64:
      scalar.Set<complex64>({}, value);
      break;
    case C128:
      scalar.Set<complex128>({}, value);
      break;
    default:
      return false;
  }
  return root_piece().IsAll(scalar);
}

bool LiteralBase::IsAllFirst() const {
  if (!shape().IsArray()) {
    return false;
  }

  // Empty shapes are not all the first element since there is no first element.
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return false;
  }

  absl::InlinedVector<int64_t, 4> start_indices(/*n=*/shape().rank(), 0);
  absl::InlinedVector<int64_t, 4> end_indices(/*n=*/shape().rank(), 1);
  Literal first = Slice(start_indices, end_indices);
  return IsAll(first.Reshape({}).ValueOrDie());
}

bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  auto is_iota_at_idx = [&](const int64_t idx) {
    switch (shape().element_type()) {
      case U8:
        return static_cast<int64_t>(Get<uint8_t>({idx})) == idx;
      case U16:
        return static_cast<int64_t>(Get<uint16_t>({idx})) == idx;
      case U32:
        return static_cast<int64_t>(Get<uint32_t>({idx})) == idx;
      case U64:
        return static_cast<int64_t>(Get<uint64_t>({idx})) == idx;
      case S8:
        return Get<int8_t>({idx}) == idx;
      case S16:
        return Get<int16_t>({idx}) == idx;
      case S32:
        return Get<int32_t>({idx}) == idx;
      case S64:
        return Get<int64_t>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // pred, token, opaque, tuple, etc. are all not iota.
      default:
        return false;
    }
  };

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  for (int64_t idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}

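// Example (illustrative): IsR1Iota recognizes the sequence 0, 1, 2, ... in a
// rank-1 literal of any numeric type.
//
//   CHECK(LiteralUtil::CreateR1<int32_t>({0, 1, 2, 3}).IsR1Iota());
//   CHECK(!LiteralUtil::CreateR1<int32_t>({1, 2, 3}).IsR1Iota());  // offset
//   CHECK(!LiteralUtil::CreateR2<int32_t>({{0, 1}}).IsR1Iota());   // rank 2
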
// Returns the stride if the literal is a strided iota, i.e., an iota
// multiplied by a stride. Only applicable to integer iotas. Returns
// std::nullopt if the literal is not a strided iota.
std::optional<int64_t> LiteralBase::IsR1StridedIota() const {
  if (!shape().IsArray() || shape().rank() != 1) {
    return std::nullopt;
  }

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  const PrimitiveType type = shape().element_type();
  if (elements <= 1 || !primitive_util::IsIntegralType(type)) {
    return std::nullopt;
  }

  auto get_element_at = [&](const int64_t idx) -> int64_t {
    switch (type) {
      case U8:
        return static_cast<int64_t>(Get<uint8_t>({idx}));
      case U16:
        return static_cast<int64_t>(Get<uint16_t>({idx}));
      case U32:
        return static_cast<int64_t>(Get<uint32_t>({idx}));
      case U64:
        return static_cast<int64_t>(Get<uint64_t>({idx}));
      case S8:
        return Get<int8_t>({idx});
      case S16:
        return Get<int16_t>({idx});
      case S32:
        return Get<int32_t>({idx});
      case S64:
        return Get<int64_t>({idx});
      default:
        CHECK(0);
        return 0;
    }
  };

  // Infer the stride from the second element, since the first element is
  // required to be zero.
  int64_t stride = get_element_at(1);
  if (stride == 0) {
    return std::nullopt;
  }

  for (int64_t idx = 0; idx < elements; ++idx) {
    if (get_element_at(idx) != idx * stride) {
      return std::nullopt;
    }
  }

  return stride;
}

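// Example (illustrative): a strided iota is 0, s, 2s, ... for some nonzero
// integer stride s.
//
//   auto lit = LiteralUtil::CreateR1<int32_t>({0, 3, 6, 9});
//   std::optional<int64_t> stride = lit.IsR1StridedIota();  // contains 3
//   // {0, 0, 0} and {1, 2, 3} both yield std::nullopt.
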
bool LiteralBase::IsZero(absl::Span<const int64_t> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8_t>(indices) == 0;
    case U16:
      return Get<uint16_t>(indices) == 0;
    case U32:
      return Get<uint32_t>(indices) == 0;
    case U64:
      return Get<uint64_t>(indices) == 0;
    case S8:
      return Get<int8_t>(indices) == 0;
    case S16:
      return Get<int16_t>(indices) == 0;
    case S32:
      return Get<int32_t>(indices) == 0;
    case S64:
      return Get<int64_t>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "Input literal must be an array.";
  }
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}

}  // namespace

void LiteralBase::Piece::set_array_value_state(ArrayValueState state) {
  array_value_state_ = state;
}

LiteralBase::ArrayValueState LiteralBase::Piece::get_array_value_state() const {
  return array_value_state_;
}

void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8_t>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8_t>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32_t>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64_t>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32_t>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64_t>());
      break;
    case U16:
      *proto->mutable_u16s() = std::string(
          reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = std::string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = std::string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() = std::string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape, which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

const void* LiteralBase::Piece::untyped_data() const {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

void* LiteralBase::Piece::untyped_data() {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
                             const RepeatedFieldT& src) {
  if (dest.size() != src.size()) {
    return InvalidArgument(
        "Expected %lu elements in LiteralProto repeated field, has %d",
        dest.size(), src.size());
  }
  std::copy(src.begin(), src.end(), dest.begin());
  return OkStatus();
}

}  // namespace

Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      auto s8_data = data<int8_t>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8_t>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32_t>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64_t>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32_t>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64_t>(), proto.u64s()));
      break;
    case S16: {
      const std::string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const std::string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const std::string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case BF16: {
      const std::string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64_t i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      const int64_t complex_data_size_doubled = complex_data.size() * 2;
      TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled);
      for (int64_t i = 0, end = complex_data.size(); i < end; ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return OkStatus();
}

bool LiteralBase::Piece::IsKnown() const {
  if (array_value_state_ != ArrayValueState::kKnown) {
    return false;
  }
  if (subshape().IsTuple()) {
    bool are_all_leaf_arrays_known = true;
    ForEachSubpiece([&are_all_leaf_arrays_known](const ShapeIndex& index,
                                                 const Piece& piece) {
      if (!piece.subshape().IsArray()) {
        return;
      }
      are_all_leaf_arrays_known &= piece.IsKnown();
    });
    return are_all_leaf_arrays_known;
  }
  return true;
}

bool LiteralBase::Piece::IsDetermined() const {
  if (array_value_state_ == ArrayValueState::kUndetermined) {
    return false;
  }
  if (subshape().IsTuple()) {
    bool are_all_leaf_arrays_determined = true;
    ForEachSubpiece([&are_all_leaf_arrays_determined](const ShapeIndex& index,
                                                      const Piece& piece) {
      if (!piece.subshape().IsArray()) {
        return;
      }
      are_all_leaf_arrays_determined &= piece.IsDetermined();
    });
    return are_all_leaf_arrays_determined;
  }
  return true;
}

LiteralProto LiteralBase::ToProto() const {
  LiteralProto proto;
  root_piece().ForEachSubpiece(
      [&](const ShapeIndex& index, const Piece& piece) {
        LiteralProto* proto_piece = &proto;
        for (int64_t i : index) {
          while (proto_piece->tuple_literals_size() <= i) {
            proto_piece->add_tuple_literals();
          }
          proto_piece = proto_piece->mutable_tuple_literals(i);
        }
        piece.WriteToProto(proto_piece);
      });

  return proto;
}

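// Example (illustrative sketch): round-tripping a literal through its proto
// form. MutableLiteralBase::CreateFromProto is assumed to be declared in
// literal.h alongside these definitions (it is referenced above in
// CopyFromProto).
//
//   Literal original = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
//   LiteralProto proto = original.ToProto();
//   StatusOr<Literal> restored = MutableLiteralBase::CreateFromProto(proto);
//   // On success, *restored == original.
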
const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
  return piece(shape_index).untyped_data();
}

void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
  return piece(shape_index).untyped_data();
}

int64_t LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
  return piece(shape_index).size_bytes();
}

std::string LiteralBase::GetR1U8AsString() const {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(shape().element_type(), U8);
  return std::string(absl::bit_cast<const char*>(data<uint8_t>().data()),
                     ShapeUtil::ElementsIn(shape()));
}

void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
                                               const Piece* src_piece,
                                               Piece* dest_piece) {
  DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
      << "src_piece has shape: "
      << ShapeUtil::HumanString(src_piece->subshape())
      << "; dest_piece has shape: "
      << ShapeUtil::HumanString(dest_piece->subshape());
  dest_piece->set_array_value_state(src_piece->get_array_value_state());
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);

      dest_piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    dest_piece->set_buffer(const_cast<char*>(src_piece->buffer()));
  } else {
    // If the shape is neither an array nor a tuple, then it must be
    // zero-sized. Otherwise, some memory would need to be allocated for it.
    CHECK_EQ(dest_piece->size_bytes(), 0);
  }
}

MutableLiteralBase::~MutableLiteralBase() {}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = literal.shape_.Clone();
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}

MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  shape_ = literal.shape_.Clone();
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);

  return *this;
}

MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
    : MutableLiteralBase() {
  shape_ = literal->shape_.Clone();
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    MutableBorrowingLiteral literal, const ShapeIndex& view_root)
    : MutableLiteralBase() {
  shape_ = std::make_unique<Shape>(literal.piece(view_root).subshape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = std::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  CHECK(!shape_->IsTuple());

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
}

MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = std::make_unique<Shape>(shape);
  if (!shape_->IsTuple()) {
    CHECK_EQ(src_buf_ptrs.size(), 1);
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());
    root_piece_->set_buffer(const_cast<char*>(src_buf_ptrs[0]));
  } else {
    CHECK(!ShapeUtil::IsNestedTuple(*shape_));
    CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());

    for (int i = 0; i < src_buf_ptrs.size(); ++i) {
      Piece child_piece;
      const auto& src_shape = shape_->tuple_shapes(i);
      CHECK(src_shape.IsArray());
      child_piece.set_subshape(&src_shape);
      child_piece.set_buffer(src_buf_ptrs[i]);
      root_piece_->emplace_back(std::move(child_piece));
    }
  }
}

MutableBorrowingLiteral::~MutableBorrowingLiteral() {
  if (root_piece_ != nullptr) {
    delete root_piece_;
  }
}

LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}

LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}

void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
  CHECK(shape.IsTuple());
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);

    auto child_piece = Piece();
    child_piece.set_subshape(&subshape);

    if (subshape.IsTuple()) {
      BuildPieceSubtree(subshape, &child_piece);
    }

    piece->emplace_back(std::move(child_piece));
  }
}

BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
    : LiteralBase(), shape_(std::make_unique<Shape>(shape)) {
  CHECK(shape_->IsArray());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
}

BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                                   const Shape& shape)
    : LiteralBase(), shape_(std::make_unique<Shape>(shape)) {
  CHECK(shape_->IsTuple());
  CHECK(!ShapeUtil::IsNestedTuple(*shape_));
  CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  BuildPieceSubtree(*shape_, &root_piece_);

  for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
    const auto& src_shape = shape_->tuple_shapes(i);
    CHECK(src_shape.IsArray());
    root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
  }
}

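// Example (illustrative sketch): BorrowingLiteral wraps caller-owned memory
// without copying, so the buffer must outlive the literal.
//
//   float data[2] = {1.0f, 2.0f};
//   BorrowingLiteral view(reinterpret_cast<const char*>(data),
//                         ShapeUtil::MakeShape(F32, {2}));
//   // `view` reads directly from `data`; no allocation takes place.
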
}  // namespace xla