/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"

namespace xla {
namespace {

using absl::StrCat;

constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Literals can be used as DMA targets, which can require alignment. We
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
// alignment.
constexpr int kMinimumAlignment = 64;

// Converts between little and big endian.
//
// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
void ConvertEndianShort(std::string* bytes) {
  CHECK_EQ(bytes->size() % 2, 0);
  for (int64_t i = 0, end = bytes->size(); i < end; i += 2) {
    std::swap((*bytes)[i], (*bytes)[i + 1]);
  }
}

void ConvertEndianShort(char* bytes, int64_t size) {
  CHECK_EQ(size % 2, 0);
  for (int64_t i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}
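
// Illustrative example for ConvertEndianShort: the byte sequence
// {0x12, 0x34, 0x56, 0x78} becomes {0x34, 0x12, 0x78, 0x56}; each adjacent
// byte pair is swapped in place.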

std::string CompactOneline(const std::string& input) {
  std::string result;
  std::vector<std::string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
  bool first = true;
  // Concatenate elements in "v" with spaces separating them, but ignoring
  // empty entries.
  for (const auto& s : v) {
    if (s.empty()) {
      continue;
    }
    absl::StrAppend(&result, (first ? "" : " "), s);
    first = false;
  }
  return result;
}
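
// Illustrative example for CompactOneline: the input "{ 1,\n  2,\n  3 }"
// collapses to "{ 1, 2, 3 }", with newlines and repeated spaces removed.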

// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  return val;
}
uint16_t GetRawValue(Eigen::half val) {
  return Eigen::numext::bit_cast<uint16_t>(val);
}
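
// Illustrative example for GetRawValue: GetRawValue(Eigen::half(1.0f)) yields
// 0x3C00, the IEEE binary16 bit pattern of 1.0, while the generic overload is
// a pass-through (GetRawValue(uint32_t{7}) returns 7).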

bool LiteralProtoHasValues(const LiteralProto& proto) {
  return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
         proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
         proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
         proto.c64s_size() || proto.c128s_size() ||
         proto.tuple_literals_size() || !proto.f16s().empty() ||
         !proto.bf16s().empty() || !proto.u16s().empty() ||
         !proto.s16s().empty();
}

// Lazy getter for the interned scalar shape in static storage. We reuse this
// shape pointer when constructing scalar Literals, which can happen a lot
// when we are evaluating reduce-like ops in HloEvaluator, and copying the
// shape over and over again significantly slows down the evaluator.
template <PrimitiveType kType>
const Shape& ScalarShapeImpl() {
  static_assert(primitive_util::IsArrayType(kType),
                "Not a valid type for a scalar.");
  static const Shape* shape = [] {
    auto shape = new Shape(kType, {}, {}, {});
    shape->mutable_layout();
    return shape;
  }();
  return *shape;
}

const Shape& ScalarShape(PrimitiveType type) {
  switch (type) {
    case U8:
      return ScalarShapeImpl<U8>();
    case U16:
      return ScalarShapeImpl<U16>();
    case U32:
      return ScalarShapeImpl<U32>();
    case U64:
      return ScalarShapeImpl<U64>();
    case S8:
      return ScalarShapeImpl<S8>();
    case S16:
      return ScalarShapeImpl<S16>();
    case S32:
      return ScalarShapeImpl<S32>();
    case S64:
      return ScalarShapeImpl<S64>();
    case F16:
      return ScalarShapeImpl<F16>();
    case BF16:
      return ScalarShapeImpl<BF16>();
    case F32:
      return ScalarShapeImpl<F32>();
    case F64:
      return ScalarShapeImpl<F64>();
    case C64:
      return ScalarShapeImpl<C64>();
    case C128:
      return ScalarShapeImpl<C128>();
    case PRED:
      return ScalarShapeImpl<PRED>();
    case TUPLE:
      LOG(FATAL) << "Tuple element type cannot be a scalar type.";
    case OPAQUE_TYPE:
      LOG(FATAL) << "Opaque element type cannot be a scalar type.";
    case TOKEN:
      LOG(FATAL) << "Token element type cannot be a scalar type.";
    case PRIMITIVE_TYPE_INVALID:
      LOG(FATAL) << "Invalid primitive type.";
    default:
      LOG(FATAL) << "Unhandled primitive type " << type;
  }
}
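
// Illustrative note on ScalarShape: ScalarShape(F32) returns a reference to a
// single process-lifetime "f32[]" Shape, so constructing many scalar literals
// does not copy a fresh Shape each time.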

const Shape& NilShape() {
  static const Shape* shape = new Shape(TUPLE, {}, {}, {});
  return *shape;
}

// Returns the interned shape pointer in static storage if it's a scalar shape
// or nil shape.
const Shape* TryInternShape(const Shape& shape) {
  if (shape.IsTuple() && shape.tuple_shapes_size() == 0) {
    return &NilShape();
  }
  if (shape.IsArray() && shape.dimensions_size() == 0 && shape.is_static() &&
      shape.layout().tiles_size() == 0 && shape.layout().memory_space() == 0) {
    return &ScalarShape(shape.element_type());
  }
  return nullptr;
}

}  // namespace

LiteralBase::~LiteralBase() {}

std::ostream& operator<<(std::ostream& out, const Literal& literal) {
  out << literal.ToString();
  return out;
}

MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64_t> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}
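
// Illustrative StrideConfig example: copying between f32[4,6]{1,0} (source)
// and f32[4,6]{0,1} (dest) with dimensions={4,6} selects the source's minor
// dimension (size 6, the larger of the two) for the tight loop, so
// minor_loop_size becomes 6 and step becomes {1, 6}; the index enumeration
// then advances one six-element block per callback.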

Shape* MutableLiteralBase::mutable_shape_do_not_use() {
  const Shape* const_shape = shape_.get();
  Shape* shape = shape_.get_mutable(/*ensure_owned=*/true);
  if (shape != const_shape) {
    std::function<void(const Shape&, Piece*)> set_piece_shapes =
        [&set_piece_shapes](const Shape& shape, Piece* piece) {
          piece->set_subshape(&shape);
          if (shape.IsTuple()) {
            for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
              const Shape& subshape = shape.tuple_shapes(i);
              set_piece_shapes(subshape, &piece->child(i));
            }
          }
        };
    set_piece_shapes(*shape, &mutable_root_piece());
  }
  return shape;
}

Literal::Literal() : Literal(NilShape()) {}

Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}

void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,
                       ArrayValueState leaf_array_value_state) {
  if (shape.IsTuple()) {
    for (const Shape& subshape : shape.tuple_shapes()) {
      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays, leaf_array_value_state);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    piece->set_array_value_state(leaf_array_value_state);
    if (leaf_array_value_state == LiteralBase::ArrayValueState::kKnown &&
        allocate_arrays) {
      piece->AllocateBuffers();
    }
  } else {
    // If the shape is neither an array nor tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}

Literal::Literal(const Shape& shape, bool allocate_arrays,
                 ArrayValueState leaf_array_value_state)
    : MutableLiteralBase() {
  if (const Shape* interned_shape_ptr = TryInternShape(shape)) {
    shape_ = interned_shape_ptr;
  } else {
    shape_ = std::make_unique<Shape>(shape);
  }
  CHECK(leaf_array_value_state != ArrayValueState::kKnown ||
        LayoutUtil::HasLayout(*shape_));
  root_piece_.set_subshape(shape_.get());
  CHECK(&root_piece_.subshape() == shape_.get());

  SetPiece(*shape_, &root_piece_, allocate_arrays, leaf_array_value_state);
}

Literal::~Literal() { DeallocateBuffers(); }

void Literal::DeallocateBuffers() {
  root_piece_.ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        piece->DeallocateBuffers();
      });
}

Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}

Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_.subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  DCHECK(&root_piece_.subshape() == shape_.get());

  return *this;
}

Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_.ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}

Literal LiteralBase::CreateFromShapeWithUnknownLeafArrays(const Shape& shape) {
  Literal literal(shape, /*allocate_arrays=*/false, ArrayValueState::kUnknown);
  return literal;
}

Literal LiteralBase::CreateFromShapeWithUndeterminedLeafArrays(
    const Shape& shape) {
  Literal literal(shape, /*allocate_arrays=*/false,
                  ArrayValueState::kUndetermined);
  return literal;
}

int32_t LiteralBase::GetDynamicSize(int64_t dim_index) const {
  return GetDynamicSize(dim_index, {});
}

int32_t LiteralBase::GetDynamicSize(int64_t dim_index,
                                    const ShapeIndex& shape_index) const {
  return piece(shape_index).GetDynamicSize(dim_index);
}

std::optional<int64_t> LiteralBase::GetFirstInteger() const {
  switch (shape().element_type()) {
    case U8:
      return GetFirstElement<uint8_t>();
    case U16:
      return GetFirstElement<uint16_t>();
    case U32:
      return GetFirstElement<uint32_t>();
    case U64: {
      // A uint64_t above INT64_MAX turns negative when assigned to int64_t;
      // report such values as not representable.
      int64_t v = GetFirstElement<uint64_t>();
      if (v < 0) {
        return std::nullopt;
      }
      return v;
    }
    case S8:
      return GetFirstElement<int8_t>();
    case S16:
      return GetFirstElement<int16_t>();
    case S32:
      return GetFirstElement<int32_t>();
    case S64:
      return GetFirstElement<int64_t>();
    default:
      return std::nullopt;
  }
}

template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64_t> src_base,
    absl::Span<const int64_t> dest_base, absl::Span<const int64_t> copy_size) {
  const int64_t src_base_size = src_base.size();
  const int64_t dest_base_size = dest_base.size();
  TF_RET_CHECK(src_literal.shape().rank() == src_base_size);
  TF_RET_CHECK(shape().rank() == dest_base_size);

  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64_t> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If either shape is a scalar, we can just call StridedCopy() directly,
    // and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform the copy only if neither src nor dest has a zero-element
    // dimension; otherwise the copy is a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64_t> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64_t>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64_t>());

      int64_t src_index = linear_index(src_literal.shape(), src_indexes);
      int64_t dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return OkStatus();
}
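
// Illustrative CopySliceFromInternal example: with src_base={1,0},
// dest_base={0,0} and copy_size={2,3}, the 2x3 window of src_literal whose
// top-left corner is at row 1 is copied into the top-left corner of this
// literal.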

Status MutableLiteralBase::CopyElementFrom(
    const LiteralSlice& src_literal, absl::Span<const int64_t> src_index,
    absl::Span<const int64_t> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64_t src_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(src_literal.shape(),
                                                    src_index);
  const int64_t dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64_t primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return OkStatus();
}

/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto, bool prohibit_empty_literal) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
    return InvalidArgument(
        "Literal shape cannot include OPAQUE_TYPE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_.ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        const LiteralProto* proto_element = &proto;
        for (int64_t i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return OkStatus();
        }
        if (piece->subshape().element_type() == TOKEN) {
          return OkStatus();
        }

        CHECK(piece->subshape().IsArray());

        // When prohibit_empty_literal is false (allowing literals with no
        // values), only copy from the proto if the literal proto has values.
        // This mode is used for a learned cost model.
        if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
          TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
        }

        return OkStatus();
      }));

  return std::move(literal);
}
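
// Illustrative usage sketch (assuming `proto` is a deserialized
// LiteralProto):
//   StatusOr<Literal> literal_or =
//       MutableLiteralBase::CreateFromProto(proto,
//                                           /*prohibit_empty_literal=*/true);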

Literal Literal::SubLiteral(ShapeIndexView shape_index) {
  if (!shape_index.empty()) {
    auto decomposed = this->DecomposeTuple();
    return decomposed.at(shape_index.front())
        .SubLiteral(shape_index.subspan(1));
  } else {
    return std::move(*this);
  }
}

std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  const auto tuple_element_count = ShapeUtil::TupleElementCount(shape());
  elements.reserve(tuple_element_count);
  for (int i = 0; i < tuple_element_count; ++i) {
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_.ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          if (dest_piece->subshape().IsTuple()) {
            return;
          }
          ShapeIndex src_index = {i};
          for (int64_t j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->MoveDataFrom(src_piece);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
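
// Illustrative DecomposeTuple example: a literal of shape (f32[2], s32[])
// decomposes into two literals of shapes f32[2] and s32[] that steal the
// original buffers, and the original literal is left nil-shaped.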

namespace {

// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64_t> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
}  // namespace

int32_t LiteralBase::Piece::GetDynamicSize(int64_t dim_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  if (!subshape_->is_dynamic_dimension(dim_index)) {
    // This is a static dimension, return its size.
    return subshape_->dimensions(dim_index);
  }
  return dynamic_size_buffer()[dim_index];
}

void LiteralBase::Piece::SetDynamicSize(int64_t dim_index, int32_t size) {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  CHECK(subshape_->is_dynamic_dimension(dim_index));
  dynamic_size_buffer()[dim_index] = size;
}

void LiteralBase::Piece::AllocateBuffers() {
  const int64_t bytes = total_bytes();
  if (bytes > kMaxInlinedBytes) {
    CHECK_EQ(buffer(), nullptr);
    rep_.emplace<ArrayRep>();
    set_buffer(static_cast<char*>(
        tensorflow::port::AlignedMalloc(bytes, kMinimumAlignment)));
  } else {
    rep_.emplace<InlinedRep>();
  }
}
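
// Note on AllocateBuffers: pieces of at most kMaxInlinedBytes live in the
// inlined representation, while larger arrays get a dedicated allocation
// aligned to kMinimumAlignment so the buffer can serve as a DMA target.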

void LiteralBase::Piece::DeallocateBuffers() {
  if (auto* array_rep = GetArrayRep()) {
    tensorflow::port::AlignedFree(array_rep->data);
    rep_.emplace<Uninitialized>();
  }
}

Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
                                    bool only_dynamic_bound) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (src.array_value_state_ == ArrayValueState::kUnknown ||
      src.array_value_state_ == ArrayValueState::kUndetermined) {
    if (array_value_state_ == ArrayValueState::kKnown) {
      DeallocateBuffers();
    }
    array_value_state_ = src.array_value_state_;
    return OkStatus();
  } else {
    CHECK(src.array_value_state_ == ArrayValueState::kKnown);
    if (array_value_state_ == ArrayValueState::kUndetermined ||
        array_value_state_ == ArrayValueState::kUnknown) {
      AllocateBuffers();
    }
    array_value_state_ = src.array_value_state_;
  }

  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    std::vector<int64_t> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                      \
  case (XLA_T):                                                             \
    if (only_dynamic_bound) {                                               \
      CopyElementsWithDynamicBound<NATIVE_T>(src);                          \
    } else {                                                                \
      CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                    subshape(), src.subshape());            \
    }                                                                       \
    break;
      COPY_ELEMENTS(U8, uint8_t);
      COPY_ELEMENTS(U16, uint16_t);
      COPY_ELEMENTS(U32, uint32_t);
      COPY_ELEMENTS(U64, uint64_t);
      COPY_ELEMENTS(S8, int8_t);
      COPY_ELEMENTS(S16, int16_t);
      COPY_ELEMENTS(S32, int32_t);
      COPY_ELEMENTS(S64, int64_t);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not "
            "implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
  if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
    memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
           src.dynamic_size_buffer_bytes());
  }
  return OkStatus();
}

void MutableLiteralBase::SetDynamicSize(int64_t dim_index, int32_t size) {
  return SetDynamicSize(dim_index, {}, size);
}

void MutableLiteralBase::SetDynamicSize(int64_t dim_index,
                                        const ShapeIndex& shape_index,
                                        int32_t size) {
  Shape* subshape =
      ShapeUtil::GetMutableSubshape(mutable_shape_do_not_use(), shape_index);
  CHECK_GE(subshape->dimensions(dim_index), size);
  if (subshape->dimensions(dim_index) == size) {
    subshape->set_dynamic_dimension(dim_index, false);
    return;
  }
  subshape->set_dynamic_dimension(dim_index, true);
  CHECK_EQ(&piece(shape_index).subshape(), subshape);

  piece(shape_index).SetDynamicSize(dim_index, size);
}

Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index,
                                    bool only_dynamic_bound) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (only_dynamic_bound) {
    auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
    auto compact_shape =
        dest_subshape.is_static() ? dest_subshape : src_subshape;
    CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
        << compact_shape.ToString() << " vs " << bound_shape.ToString();
  } else {
    if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
      return InvalidArgument(
          "Destination subshape incompatible with source subshape: %s vs %s",
          ShapeUtil::HumanString(dest_subshape),
          ShapeUtil::HumanString(src_subshape));
    }
  }
  return mutable_root_piece().ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return OkStatus();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return OkStatus();
        }
        // Construct the index of the corresponding piece in the source
        // literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64_t i = dest_shape_index.size(), end = index.size(); i < end;
             ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(
            piece->CopyFrom(src_literal.piece(src_piece_index),
                            /*only_dynamic_bound=*/only_dynamic_bound));
        return OkStatus();
      });
}

Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_.ForEachMutableSubpiece(
      [&](const ShapeIndex& src_index, Piece* src_piece) {
        if (!src_piece->subshape().IsArray()) {
          return;
        }

        ShapeIndex dest_index = dest_shape_index;
        for (int64_t i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        dest_piece.DeallocateBuffers();
        dest_piece.MoveDataFrom(*src_piece);
      });

  src_literal.shape_ = MaybeOwningShapePtr(&NilShape());
  src_literal.root_piece_ = Piece();
  src_literal.root_piece_.set_subshape(src_literal.shape_.get());

  return OkStatus();
}
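
// Illustrative MoveFrom usage:
// TF_RETURN_IF_ERROR(dest.MoveFrom(std::move(src), {0})) hands src's buffers
// to tuple element 0 of `dest` without copying, and leaves `src` nil-shaped.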

Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64_t> src_base,
                                         absl::Span<const int64_t> dest_base,
                                         absl::Span<const int64_t> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case U16:
      return CopySliceFromInternal<uint16_t>(src_literal, src_base, dest_base,
                                             copy_size);
    case U32:
      return CopySliceFromInternal<uint32_t>(src_literal, src_base, dest_base,
                                             copy_size);
    case U64:
      return CopySliceFromInternal<uint64_t>(src_literal, src_base, dest_base,
                                             copy_size);
    case S8:
      return CopySliceFromInternal<int8_t>(src_literal, src_base, dest_base,
                                           copy_size);
    case S16:
      return CopySliceFromInternal<int16_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case S32:
      return CopySliceFromInternal<int32_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case S64:
      return CopySliceFromInternal<int64_t>(src_literal, src_base, dest_base,
                                            copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base,
                                               dest_base, copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}

void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(element_count(), values.bits());
  CHECK_EQ(shape().element_type(), PRED);
  for (int64_t i = 0; i < static_cast<int64_t>(values.bits()); ++i) {
    Set({i}, values.get(i));
  }
}

Literal LiteralBase::Relayout(const Layout& new_layout,
                              const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout "
      << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}

Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
  CHECK(bounded_shape.is_dynamic());
  Literal result(bounded_shape);
  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape.rank(); ++i) {
          result.SetDynamicSize(i, subshape.dimensions(i));
        }
      });
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));

  return result;
}

Literal LiteralBase::ToStatic() const {
  // Create a new shape in which every dynamic dimension is replaced by its
  // current dynamic size and marked static.
  Shape new_shape = shape();
  ShapeUtil::ForEachMutableSubshape(
      &new_shape, [this](Shape* subshape, const ShapeIndex& index) {
        if (!subshape->IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape->rank(); ++i) {
          subshape->set_dynamic_dimension(i, false);
          subshape->set_dimensions(i, GetDynamicSize(i, index));
        }
      });
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
  return result;
}
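
// Illustrative ToStatic example: a literal of shape f32[<=10] whose dynamic
// size is 3 becomes a literal of static shape f32[3].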

StatusOr<Literal> LiteralBase::Broadcast(
    const Shape& result_shape, absl::Span<const int64_t> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Broadcast only supports arrays.");
  }

  for (int64_t i = 0, end = dimensions.size(); i < end; i++) {
    TF_RET_CHECK(shape().dimensions(i) ==
                 result_shape.dimensions(dimensions[i]));
  }

  TF_RET_CHECK(result_shape.element_type() == shape().element_type());
  Literal result(result_shape);
  // scratch_source_index is temporary storage space for the computed index
  // into the input literal. We put it here to avoid allocating an std::vector
  // in every iteration of ShapeUtil::ForEachIndex.
  std::vector<int64_t> scratch_source_index(shape().dimensions_size());

  char* dest_data = static_cast<char*>(result.untyped_data());
  const char* source_data = static_cast<const char*>(untyped_data());
  const int64_t primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
  for (int64_t i = 0; i < dimensions.size(); ++i) {
    int64_t dynamic_size = GetDynamicSize(i);
    result.SetDynamicSize(dimensions[i], dynamic_size);
  }

  ShapeUtil::ForEachIndex(
      result_shape, [&](absl::Span<const int64_t> output_index) {
        for (int64_t i = 0, end = dimensions.size(); i < end; ++i) {
          scratch_source_index[i] = output_index[dimensions[i]];
        }
        int64_t dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            result_shape, output_index);
        int64_t source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            shape(), scratch_source_index);
        memcpy(dest_data + primitive_size * dest_index,
               source_data + primitive_size * source_index, primitive_size);
        return true;
      });

  return std::move(result);
}
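
// Illustrative Broadcast example: broadcasting an f32[2] literal {1, 2} to
// result shape f32[2,3] with dimensions={0} (source dim 0 maps to result
// dim 0) replicates along the new column dimension, yielding
// {{1, 1, 1}, {2, 2, 2}}.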

StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64_t> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  if (shape().is_dynamic()) {
    return Unimplemented("Dynamic reshape is not implemented.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64_t elements_before = ShapeUtil::ElementsIn(shape());
  int64_t elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
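
// Illustrative Reshape example: reshaping an f32[2,3] literal
// {{a,b,c},{d,e,f}} to dimensions {3,2} keeps the row-major element sequence
// a,b,c,d,e,f and regroups it as {{a,b},{c,d},{e,f}}.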

Literal LiteralBase::Transpose(absl::Span<const int64_t> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the
  // one of size 8) as the most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  std::vector<int64_t> inverse_permutation = InversePermutation(permutation);
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  for (int64_t i = 0; i < shape().rank(); i++) {
    new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
  }
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}
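
// Illustrative Transpose example: Transpose({1,0}) on an f32[11,8]{1,0}
// literal returns an f32[8,11]{0,1} literal whose underlying bytes are copied
// unchanged, because the permuted layout keeps every element at the same
// linear offset.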

template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64_t> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  CHECK(result_literal
            .Populate<NativeT>([&](absl::Span<const int64_t> indices) {
              for (int64_t i = 0; i < result_shape.rank(); ++i) {
                new_indices[i] = indices[i] + start_indices[i];
              }
              return Get<NativeT>(new_indices);
            })
            .ok());
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    if (shape().is_dynamic_dimension(dnum)) {
      int64_t dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
      CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
      dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
      result_literal.SetDynamicSize(dnum, dynamic_size);
    }
  }
  return result_literal;
}

Literal LiteralBase::Slice(absl::Span<const int64_t> start_indices,
                           absl::Span<const int64_t> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64_t dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8_t>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16_t>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32_t>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64_t>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8_t>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16_t>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32_t>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64_t>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}
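
// Illustrative Slice example: Slice({0, 1}, {2, 3}) on an f32[3,4] literal
// returns the f32[2,2] sub-array covering rows [0, 2) and columns [1, 3).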

Literal LiteralBase::Clone() const {
  Literal result(shape());
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
  auto result = std::make_unique<Literal>(shape());
  TF_CHECK_OK(result->CopyFrom(*this));
  return result;
}

bool LiteralBase::IsDetermined(const ShapeIndex& shape_index) const {
  return piece(shape_index).IsDetermined();
}

bool LiteralBase::IsKnown(const ShapeIndex& shape_index) const {
  return piece(shape_index).IsKnown();
}

std::string LiteralBase::GetAsString(absl::Span<const int64_t> multi_index,
                                     const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8_t>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16_t>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32_t>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64_t>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8_t>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16_t>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32_t>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64_t>(multi_index, shape_index));
    case F16:
      return RoundTripFpToString(Get<half>(multi_index, shape_index));
    case F32:
      return RoundTripFpToString(Get<float>(multi_index, shape_index));
    case BF16:
      return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
    case F64:
      return RoundTripFpToString(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}

std::optional<int64_t> LiteralBase::GetIntegralAsS64(
    absl::Span<const int64_t> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case S8:
      return Get<int8_t>(multi_index);
    case U8:
      return Get<uint8_t>(multi_index);
    case S16:
      return Get<int16_t>(multi_index);
    case U16:
      return Get<uint16_t>(multi_index);
    case S32:
      return Get<int32_t>(multi_index);
    case U32:
      return Get<uint32_t>(multi_index);
    case S64:
      return Get<int64_t>(multi_index);
    case U64:
      return Get<uint64_t>(multi_index);
    default:
      return std::nullopt;
  }
}

std::optional<double> LiteralBase::GetAsDouble(
    absl::Span<const int64_t> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      return static_cast<double>(Get<half>(multi_index));
    case F32:
      return static_cast<double>(Get<float>(multi_index));
    case F64:
      return Get<double>(multi_index);
    case BF16:
      return static_cast<double>(Get<bfloat16>(multi_index));
    default:
      return std::nullopt;
  }
}

std::optional<complex128> LiteralBase::GetAsComplex128(
    absl::Span<const int64_t> multi_index) const {
  switch (shape().element_type()) {
    case BF16:
      return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
    case F16:
      return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
    case F32:
      return {{Get<float>(multi_index), 0}};
    case F64:
      return {{Get<double>(multi_index), 0}};
    case C64:
      return {Get<complex64>(multi_index)};
    case C128:
      return {Get<complex128>(multi_index)};
    case S8:
      return {Get<int8_t>(multi_index)};
    default:
      return std::nullopt;
  }
}

Status MutableLiteralBase::SetIntegralAsS64(
    absl::Span<const int64_t> multi_index, int64_t value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      Set<bool>(multi_index, value);
      break;
    case U8:
      Set<uint8_t>(multi_index, value);
      break;
    case S32:
      Set<int32_t>(multi_index, value);
      break;
    case S64:
      Set<int64_t>(multi_index, value);
      break;
    case U32:
      Set<uint32_t>(multi_index, value);
      break;
    case U64:
      Set<uint64_t>(multi_index, value);
      break;
    default:
      return FailedPrecondition("Array element type is not integral: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return OkStatus();
}

Status MutableLiteralBase::SetFromDouble(absl::Span<const int64_t> multi_index,
                                         double value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      Set<half>(multi_index, Eigen::half(value));
      break;
    case F32:
      Set<float>(multi_index, value);
      break;
    case F64:
      Set<double>(multi_index, value);
      break;
    case BF16:
      Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
      break;
    default:
      return FailedPrecondition("Array element type is not floating: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return OkStatus();
}

namespace {

std::string ShapeToString(bool print_layout, const Shape& shape) {
  return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
                      : ShapeUtil::HumanString(shape);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<std::string>* pieces);

void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<std::string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<std::string> tuple_pieces;
  const auto tuple_element_count = ShapeUtil::TupleElementCount(subshape);
  tuple_pieces.reserve(tuple_element_count);
  for (int i = 0; i < tuple_element_count; ++i) {
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<std::string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}

void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout,
                              std::vector<std::string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64_t rank = subshape.rank();

  std::function<void(absl::Span<const int64_t> dimensions,
                     std::vector<int64_t>*)>
      to_string_recursive = [&](absl::Span<const int64_t> dimensions,
                                std::vector<int64_t>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        auto brace_to_string = [&](std::string brace) -> std::string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            const int64_t accum_indices_size = accum_indices->size();
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices_size < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          std::string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            accum_indices->push_back(i);
            to_string_recursive(dimensions.subspan(1), accum_indices);
            accum_indices->pop_back();
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    if (subshape.is_dynamic()) {
      pieces->push_back("(");
      for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
        pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
        if (i < subshape.dimensions_size() - 1) {
          pieces->push_back(",");
        }
      }
      pieces->push_back(")");
    }
    pieces->push_back(" ");
  }
  std::vector<int64_t> indices = {};
  std::vector<int64_t> dimensions;
  dimensions.reserve(subshape.rank());
  for (int64_t i = 0; i < subshape.rank(); ++i) {
    dimensions.push_back(literal.GetDynamicSize(i, shape_index));
  }
  to_string_recursive(dimensions, &indices);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<std::string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    if (literal.IsKnown(shape_index)) {
      DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                               pieces);
    } else {
      pieces->push_back(ShapeToString(print_layout, subshape));
      pieces->push_back(" ");
      if (literal.IsDetermined(shape_index)) {
        pieces->push_back("unknown");
      } else {
        pieces->push_back("undetermined");
      }
    }
  }
}

}  // namespace

std::string LiteralBase::ToString() const {
  std::vector<std::string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

std::string LiteralBase::ToStringOneline() const {
  return CompactOneline(ToString());
}

std::string LiteralBase::ToStringWithoutShape() const {
  std::vector<std::string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/false,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

std::string LiteralBase::ToStringWithoutShapeOneline() const {
  return CompactOneline(ToStringWithoutShape());
}

std::string LiteralBase::ToStringWithLayout() const {
  std::vector<std::string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/true, &pieces);
  return absl::StrJoin(pieces, "");
}

std::string LiteralBase::ToStringWithLayoutOneline() const {
  return CompactOneline(ToStringWithLayout());
}

void LiteralBase::EachCellAsString(
    const std::function<void(absl::Span<const int64_t> indices,
                             const std::string& value)>& per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  std::vector<int64_t> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
      shape(), /*linear_index=*/0);
  do {
    per_cell(indices, GetAsString(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
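
// Usage sketch for EachCellAsString() above (illustrative only):
//   auto literal = LiteralUtil::CreateR2<int32_t>({{1, 2}, {3, 4}});
//   literal.EachCellAsString(
//       [](absl::Span<const int64_t> indices, const std::string& value) {
//         // Visits {0,0}->"1", {0,1}->"2", {1,0}->"3", {1,1}->"4", since
//         // BumpIndices advances the last dimension first.
//       });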

namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
                                               const ConverterType& converter) {
  CHECK(src_literal.shape().IsArray());
  Literal result_literal(ShapeUtil::ChangeElementType(
      src_literal.shape(),
      primitive_util::NativeToPrimitiveType<NativeDestT>()));
  auto src_data = src_literal.data<NativeSrcT>();
  auto dest_data = result_literal.template data<NativeDestT>();
  int64_t num_elements = src_literal.element_count();

  for (int64_t i = 0; i < num_elements; ++i) {
    dest_data[i] = converter(src_data[i]);
  }
  return result_literal;
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_same<NativeSrcT, Eigen::half>::value &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_floating_point<NativeSrcT>::value &&
                            std::is_integral<NativeDestT>::value,
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    // C++ [conv.bool]p1:
    //   A prvalue of arithmetic [...] type can be converted to a prvalue of
    //   type bool. A zero value [...] is converted to false; any other value
    //   is converted to true.
    // C++ [conv.fpint]p1:
    //   [...] The behavior is undefined if the truncated value cannot be
    //   represented in the destination type.
    //
    // Using static_cast to convert a float to an integral type other than
    // bool may be undefined if the value's magnitude is too large or it is a
    // NaN. Let's choose saturating arithmetic as it captures the spirit of
    // infinity and arbitrarily map NaN to zero.
    if (!std::is_same<NativeDestT, bool>::value) {
      if (src != src) {
        return NativeDestT{0};
      }
      if (src >= std::numeric_limits<NativeDestT>::max()) {
        return std::numeric_limits<NativeDestT>::max();
      }
      if (src <= std::numeric_limits<NativeDestT>::lowest()) {
        return std::numeric_limits<NativeDestT>::lowest();
      }
    }
    return static_cast<NativeDestT>(src);
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
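
// For example, with the saturating converter above, an F32 -> S32 conversion
// maps 1e30f to 2147483647 (INT32_MAX), -1e30f to -2147483648 (INT32_MIN),
// and NaN to 0, whereas a bare static_cast would be undefined behavior for
// all three inputs.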

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<!(std::is_floating_point<NativeSrcT>::value &&
                          std::is_integral<NativeDestT>::value) &&
                            !(std::is_same<NativeSrcT, Eigen::half>::value &&
                              (std::is_same<NativeDestT, complex64>::value ||
                               std::is_same<NativeDestT, complex128>::value)),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
                         !std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return absl::bit_cast<NativeDestT>(GetRawValue(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short first.
  auto converter = [](NativeSrcT src) {
    return Eigen::numext::bit_cast<Eigen::half>(
        absl::bit_cast<uint16_t>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}

// This template specialization is here to make the compiler happy. bit_cast
// has a static check that the types are the same size. This specialization
// should never be used because the source and destination types are checked
// for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}

template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}

StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
    // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}

}  // namespace

StatusOr<Literal> LiteralBase::Convert(
    PrimitiveType primitive_dest_type) const {
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
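
// Usage sketch for Convert() above (illustrative only):
//   Literal f = LiteralUtil::CreateR1<float>({1.9f, -2.9f});
//   StatusOr<Literal> s = f.Convert(S32);
//   // On success, *s holds {1, -2}: values are truncated toward zero,
//   // saturated at the int32 limits, and NaN is mapped to 0.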

StatusOr<Literal> LiteralBase::BitcastConvert(const Shape& dest_shape) const {
  if (ShapeUtil::ByteSizeOf(dest_shape) != ShapeUtil::ByteSizeOf(shape())) {
    return InvalidArgument(
        "Can not bitcast-convert from shape %s to a shape of different size %s",
        shape().ToString(), dest_shape.ToString());
  }
  if (dest_shape.IsTuple() || shape().IsTuple()) {
    return InvalidArgument(
        "bitcast-convert is not valid for tuple shapes %s->%s",
        shape().ToString(), dest_shape.ToString());
  }
  if (shape().is_dynamic() || dest_shape.is_dynamic()) {
    return InvalidArgument(
        "bitcast-convert is not valid for dynamic shape %s->%s",
        shape().ToString(), dest_shape.ToString());
  }

  Literal out(dest_shape);
  std::memcpy(out.root_piece_.buffer(), root_piece().buffer(),
              root_piece().size_bytes());
  return out;
}
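
// Example for BitcastConvert() above: reinterpreting an f32[1] literal holding
// 1.0f as u32[1] yields 0x3F800000, the IEEE-754 bit pattern of 1.0f; the
// bytes are copied verbatim and no numeric conversion takes place.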

StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
  if (!dest_shape.IsTuple()) {
    return Convert(dest_shape.element_type());
  }
  std::vector<Literal> elements;
  const auto tuple_element_count = ShapeUtil::TupleElementCount(shape());
  elements.reserve(tuple_element_count);
  for (int i = 0; i < tuple_element_count; ++i) {
    auto element = LiteralSlice(*this, {i});
    TF_ASSIGN_OR_RETURN(
        auto new_element,
        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
    elements.push_back(std::move(new_element));
  }
  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}

/* static */ Literal MutableLiteralBase::MoveIntoTuple(
    absl::Span<Literal> elements) {
  std::vector<const Shape*> element_shapes;
  element_shapes.reserve(elements.size());
  for (const Literal& element : elements) {
    element_shapes.push_back(&element.shape());
  }
  Literal literal(ShapeUtil::MakeTupleShapeWithPtrs(element_shapes),
                  /*allocate_arrays=*/false);
  for (int i = 0, end = elements.size(); i < end; ++i) {
    TF_CHECK_OK(
        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
  }
  return literal;
}
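
// Usage sketch for MoveIntoTuple() above (illustrative only):
//   std::vector<Literal> elements;
//   elements.push_back(LiteralUtil::CreateR0<int32_t>(42));
//   elements.push_back(LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
//   Literal tuple = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
//   // `tuple` has shape (s32[], f32[2]); the element buffers were moved, not
//   // copied, and `elements` is left holding moved-from literals.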

template <typename NativeT>
void LiteralBase::Piece::CopyElementsWithDynamicBound(
    const LiteralBase::Piece& src) {
  auto dest_shape = subshape();
  auto src_shape = src.subshape();

  // At least one of the shapes has to have a static bound.
  CHECK(dest_shape.is_static() || src_shape.is_static());
  auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64_t> index(dest_shape.rank());
  do {
    bool out_of_bound = false;
    for (int64_t i = 0; i < index.size(); ++i) {
      // Do not copy elements beyond the dynamic bound.
      if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
        out_of_bound = true;
      }
    }
    if (out_of_bound) {
      continue;
    }
    data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
                                                                  index)] =
        src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
            src_shape, index)];
  } while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
}

template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64_t>* multi_index) const {
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64_t i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}

bool LiteralBase::Piece::EqualDynamicSize(
    const LiteralBase::Piece& other) const {
  DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
  if (subshape().is_static()) {
    return true;
  }

  for (int64_t i = 0; i < subshape().rank(); ++i) {
    if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
      return false;
    }
  }
  return true;
}

bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
  if (subshape().is_static() &&
      ShapeUtil::Equal(subshape(), other.subshape()) &&
      LayoutUtil::IsDenseArray(subshape())) {
    CHECK_EQ(size_bytes(), other.size_bytes());
    return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
  }

  std::vector<int64_t> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case S8:
      return EqualElementsInternal<int8_t>(other, &multi_index);
    case S16:
      return EqualElementsInternal<int16_t>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32_t>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64_t>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8_t>(other, &multi_index);
    case U16:
      return EqualElementsInternal<uint16_t>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32_t>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64_t>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    case C128:
      return EqualElementsInternal<complex128>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

bool LiteralBase::operator==(const LiteralBase& other) const {
  // Check the structure of tuple literals here; checks for dense arrays are
  // performed below.
  if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
    return false;
  }

  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        const Piece& other_piece = other.piece(index);
        const Shape& subshape = piece.subshape();
        const Shape& other_subshape = other_piece.subshape();
        if (subshape.element_type() != other_subshape.element_type()) {
          return false;
        }
        if (!piece.subshape().IsArray()) {
          return true;
        }
        if (subshape.rank() != other_subshape.rank()) {
          return false;
        }

        for (int64_t i = 0; i < subshape.rank(); ++i) {
          if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
            return false;
          }
        }

        if (!piece.EqualElements(other_piece)) {
          return false;
        }
        return true;
      });
}

template <typename NativeT>
static bool EqualIncludingNan(NativeT a, NativeT b) {
  // msvc can't compile std::isnan(a) where `a` is uint8_t. This is a bug
  // according to https://en.cppreference.com/w/cpp/numeric/math/isnan, but
  // it's easy to work around.
  return a == b || (std::isnan(static_cast<double>(a)) &&
                    std::isnan(static_cast<double>(b)));
}
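
// Unlike a plain comparison, EqualIncludingNan treats two NaNs as equal:
// EqualIncludingNan(std::nanf(""), std::nanf("")) is true, whereas
// std::nanf("") == std::nanf("") is false under IEEE-754 semantics.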

template <typename T>
static bool EqualIncludingNan(std::complex<T> a, std::complex<T> b) {
  return EqualIncludingNan(a.real(), b.real()) &&
         EqualIncludingNan(a.imag(), b.imag());
}

template <typename NativeT>
static bool AllElementsEqualValue(absl::Span<const NativeT> data,
                                  NativeT value) {
  for (int64_t i = 0; i < data.size(); ++i) {
    if (!EqualIncludingNan(data[i], value)) {
      return false;
    }
  }
  return true;
}

bool Literal::Piece::IsAll(const Literal& scalar) const {
  CHECK(ShapeUtil::IsScalar(scalar.shape())) << scalar.shape().ToString();
  if (!subshape().IsArray()) {
    return false;
  }

  CHECK_EQ(subshape().element_type(), scalar.shape().element_type());
  switch (subshape().element_type()) {
    case U8:
      return AllElementsEqualValue(data<uint8_t>(),
                                   scalar.GetFirstElement<uint8_t>());
    case U16:
      return AllElementsEqualValue(data<uint16_t>(),
                                   scalar.GetFirstElement<uint16_t>());
    case U32:
      return AllElementsEqualValue(data<uint32_t>(),
                                   scalar.GetFirstElement<uint32_t>());
    case U64:
      return AllElementsEqualValue(data<uint64_t>(),
                                   scalar.GetFirstElement<uint64_t>());
    case S8:
      return AllElementsEqualValue(data<int8_t>(),
                                   scalar.GetFirstElement<int8_t>());
    case S16:
      return AllElementsEqualValue(data<int16_t>(),
                                   scalar.GetFirstElement<int16_t>());
    case S32:
      return AllElementsEqualValue(data<int32_t>(),
                                   scalar.GetFirstElement<int32_t>());
    case S64:
      return AllElementsEqualValue(data<int64_t>(),
                                   scalar.GetFirstElement<int64_t>());
    case PRED:
      return AllElementsEqualValue(data<bool>(),
                                   scalar.GetFirstElement<bool>());
    case F16:
      return AllElementsEqualValue(data<half>(),
                                   scalar.GetFirstElement<half>());
    case BF16:
      return AllElementsEqualValue(data<bfloat16>(),
                                   scalar.GetFirstElement<bfloat16>());
    case F32:
      return AllElementsEqualValue(data<float>(),
                                   scalar.GetFirstElement<float>());
    case F64:
      return AllElementsEqualValue(data<double>(),
                                   scalar.GetFirstElement<double>());
    case C64:
      return AllElementsEqualValue(data<complex64>(),
                                   scalar.GetFirstElement<complex64>());
    case C128:
      return AllElementsEqualValue(data<complex128>(),
                                   scalar.GetFirstElement<complex128>());
    default:
      return false;
  }
}

bool LiteralBase::IsAll(const Literal& scalar) const {
  return root_piece().IsAll(scalar);
}

bool LiteralBase::IsAll(int8_t value) const {
  if (!shape().IsArray()) {
    return false;
  }
  PrimitiveType ty = shape().element_type();
  if (primitive_util::IsFloatingPointType(ty)) {
    return IsAllFloat(value);
  }
  if (primitive_util::IsUnsignedIntegralType(ty) && value < 0) {
    return false;
  }
  Literal scalar(ShapeUtil::MakeScalarShape(ty));
  switch (ty) {
    case U8:
      scalar.Set<uint8_t>({}, value);
      break;
    case U16:
      scalar.Set<uint16_t>({}, value);
      break;
    case U32:
      scalar.Set<uint32_t>({}, value);
      break;
    case U64:
      scalar.Set<uint64_t>({}, value);
      break;
    case S8:
      scalar.Set<int8_t>({}, value);
      break;
    case S16:
      scalar.Set<int16_t>({}, value);
      break;
    case S32:
      scalar.Set<int32_t>({}, value);
      break;
    case S64:
      scalar.Set<int64_t>({}, value);
      break;
    case PRED:
      if (value == 0) {
        scalar.Set<bool>({}, false);
      } else if (value == 1) {
        scalar.Set<bool>({}, true);
      } else {
        return false;
      }
      break;
    default:
      return false;
  }
  return root_piece().IsAll(scalar);
}

bool LiteralBase::IsAllFloat(float value) const {
  if (!shape().IsArray()) {
    return false;
  }
  PrimitiveType ty = shape().element_type();
  Literal scalar(ShapeUtil::MakeScalarShape(ty));
  switch (ty) {
    case F16:
      scalar.Set<half>({}, static_cast<half>(value));
      break;
    case BF16:
      scalar.Set<bfloat16>({}, static_cast<bfloat16>(value));
      break;
    case F32:
      scalar.Set<float>({}, value);
      break;
    case F64:
      scalar.Set<double>({}, value);
      break;
    default:
      return false;
  }
  return root_piece().IsAll(scalar);
}

bool LiteralBase::IsAllComplex(complex64 value) const {
  if (!shape().IsArray()) {
    return false;
  }
  PrimitiveType ty = shape().element_type();
  Literal scalar(ShapeUtil::MakeScalarShape(ty));
  switch (ty) {
    case C64:
      scalar.Set<complex64>({}, value);
      break;
    case C128:
      scalar.Set<complex128>({}, value);
      break;
    default:
      return false;
  }
  return root_piece().IsAll(scalar);
}

bool LiteralBase::IsAllFirst() const {
  if (!shape().IsArray()) {
    return false;
  }

  // Empty shapes are not all the first element since there is no first
  // element.
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return false;
  }

  absl::InlinedVector<int64_t, 4> start_indices(/*n=*/shape().rank(), 0);
  absl::InlinedVector<int64_t, 4> end_indices(/*n=*/shape().rank(), 1);
  Literal first = Slice(start_indices, end_indices);
  return IsAll(first.Reshape({}).ValueOrDie());
}
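
// Example for IsAllFirst() above: an R1 literal {5, 5, 5} returns true and
// {5, 6, 5} returns false; a zero-element array returns false because there
// is no first element to compare against.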

bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  auto is_iota_at_idx = [&](const int64_t idx) {
    switch (shape().element_type()) {
      case U8:
        return static_cast<int64_t>(Get<uint8_t>({idx})) == idx;
      case U16:
        return static_cast<int64_t>(Get<uint16_t>({idx})) == idx;
      case U32:
        return static_cast<int64_t>(Get<uint32_t>({idx})) == idx;
      case U64:
        return static_cast<int64_t>(Get<uint64_t>({idx})) == idx;
      case S8:
        return Get<int8_t>({idx}) == idx;
      case S16:
        return Get<int16_t>({idx}) == idx;
      case S32:
        return Get<int32_t>({idx}) == idx;
      case S64:
        return Get<int64_t>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // pred, token, opaque, tuple, etc. are all not iota.
      default:
        return false;
    }
  };

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  for (int64_t idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}
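
// Example for IsR1Iota() above: an s32[4] literal {0, 1, 2, 3} is an iota,
// while {1, 2, 3, 4} (wrong starting value) and {0, 2, 4, 6} (stride 2) are
// not.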

// Returns a stride if the literal is a strided iota, i.e., an iota multiplied
// by a stride. Only applicable to integer iotas. Returns std::nullopt if the
// literal is not a strided iota.
std::optional<int64_t> LiteralBase::IsR1StridedIota() const {
  if (!shape().IsArray() || shape().rank() != 1) {
    return std::nullopt;
  }

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  const PrimitiveType type = shape().element_type();
  if (elements <= 1 || !primitive_util::IsIntegralType(type)) {
    return std::nullopt;
  }

  auto get_element_at = [&](const int64_t idx) -> int64_t {
    switch (type) {
      case U8:
        return static_cast<int64_t>(Get<uint8_t>({idx}));
      case U16:
        return static_cast<int64_t>(Get<uint16_t>({idx}));
      case U32:
        return static_cast<int64_t>(Get<uint32_t>({idx}));
      case U64:
        return static_cast<int64_t>(Get<uint64_t>({idx}));
      case S8:
        return Get<int8_t>({idx});
      case S16:
        return Get<int16_t>({idx});
      case S32:
        return Get<int32_t>({idx});
      case S64:
        return Get<int64_t>({idx});
      default:
        CHECK(0);
        return 0;
    }
  };

  // Infer the stride from the second element, since the first element is
  // required to be zero.
  int64_t stride = get_element_at(1);
  if (stride == 0) {
    return std::nullopt;
  }

  for (int64_t idx = 0; idx < elements; ++idx) {
    if (get_element_at(idx) != idx * stride) {
      return std::nullopt;
    }
  }

  return stride;
}
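
// Example for IsR1StridedIota() above: an s32[4] literal {0, 2, 4, 6} yields
// stride 2, while {0, 0, 0} yields std::nullopt (a zero stride is rejected),
// as does any literal whose first element is nonzero.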

bool LiteralBase::IsZero(absl::Span<const int64_t> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8_t>(indices) == 0;
    case U16:
      return Get<uint16_t>(indices) == 0;
    case U32:
      return Get<uint32_t>(indices) == 0;
    case U64:
      return Get<uint64_t>(indices) == 0;
    case S8:
      return Get<int8_t>(indices) == 0;
    case S16:
      return Get<int16_t>(indices) == 0;
    case S32:
      return Get<int32_t>(indices) == 0;
    case S64:
      return Get<int64_t>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "Input literal must be an array.";
  }
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}

}  // namespace

void LiteralBase::Piece::set_array_value_state(ArrayValueState state) {
  array_value_state_ = state;
}

LiteralBase::ArrayValueState LiteralBase::Piece::get_array_value_state() const {
  return array_value_state_;
}

void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8_t>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8_t>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32_t>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64_t>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32_t>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64_t>());
      break;
    case U16:
      *proto->mutable_u16s() =
          std::string(reinterpret_cast<const char*>(data<uint16_t>().data()),
                      size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = std::string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = std::string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() =
          std::string(reinterpret_cast<const char*>(data<bfloat16>().data()),
                      size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape, which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

const void* LiteralBase::Piece::untyped_data() const {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

void* LiteralBase::Piece::untyped_data() {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
                             const RepeatedFieldT& src) {
  if (dest.size() != src.size()) {
    return InvalidArgument(
        "Expected %lu elements in LiteralProto repeated field, has %d",
        dest.size(), src.size());
  }
  std::copy(src.begin(), src.end(), dest.begin());
  return OkStatus();
}

}  // namespace

Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      auto s8_data = data<int8_t>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8_t>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32_t>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64_t>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32_t>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64_t>(), proto.u64s()));
      break;
    case S16: {
      const std::string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const std::string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const std::string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case BF16: {
      const std::string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64_t i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      const int64_t complex_data_size_doubled = complex_data.size() * 2;
      TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled);
      for (int64_t i = 0, end = complex_data.size(); i < end; ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return OkStatus();
}

bool LiteralBase::Piece::IsKnown() const {
  if (array_value_state_ != ArrayValueState::kKnown) {
    return false;
  }
  if (subshape().IsTuple()) {
    bool are_all_leaf_arrays_known = true;
    ForEachSubpiece([&are_all_leaf_arrays_known](const ShapeIndex& index,
                                                 const Piece& piece) {
      if (!piece.subshape().IsArray()) {
        return;
      }
      are_all_leaf_arrays_known &= piece.IsKnown();
    });
    return are_all_leaf_arrays_known;
  }
  return true;
}

bool LiteralBase::Piece::IsDetermined() const {
  if (array_value_state_ == ArrayValueState::kUndetermined) {
    return false;
  }
  if (subshape().IsTuple()) {
    bool are_all_leaf_arrays_determined = true;
    ForEachSubpiece([&are_all_leaf_arrays_determined](const ShapeIndex& index,
                                                      const Piece& piece) {
      if (!piece.subshape().IsArray()) {
        return;
      }
      are_all_leaf_arrays_determined &= piece.IsDetermined();
    });
    return are_all_leaf_arrays_determined;
  }
  return true;
}

LiteralProto LiteralBase::ToProto() const {
  LiteralProto proto;
  root_piece().ForEachSubpiece(
      [&](const ShapeIndex& index, const Piece& piece) {
        LiteralProto* proto_piece = &proto;
        for (int64_t i : index) {
          while (proto_piece->tuple_literals_size() <= i) {
            proto_piece->add_tuple_literals();
          }
          proto_piece = proto_piece->mutable_tuple_literals(i);
        }
        piece.WriteToProto(proto_piece);
      });

  return proto;
}
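
// Round-trip sketch for ToProto() above (illustrative only; CreateFromProto
// is declared on MutableLiteralBase and defined elsewhere):
//   LiteralProto proto = literal.ToProto();
//   StatusOr<Literal> copy = MutableLiteralBase::CreateFromProto(proto);
//   // On success, *copy compares equal to `literal` for supported types.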

const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
  return piece(shape_index).untyped_data();
}

void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
  return piece(shape_index).untyped_data();
}

int64_t LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
  return piece(shape_index).size_bytes();
}

std::string LiteralBase::GetR1U8AsString() const {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(shape().element_type(), U8);
  return std::string(absl::bit_cast<const char*>(data<uint8_t>().data()),
                     ShapeUtil::ElementsIn(shape()));
}

void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
                                               const Piece* src_piece,
                                               Piece* dest_piece) {
  DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
      << "src_piece has shape: "
      << ShapeUtil::HumanString(src_piece->subshape())
      << ", dest_piece has shape: "
      << ShapeUtil::HumanString(dest_piece->subshape());
  dest_piece->set_array_value_state(src_piece->get_array_value_state());
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);

      dest_piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    dest_piece->set_buffer(const_cast<char*>(src_piece->buffer()));
  } else {
    // If the shape is neither an array nor a tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(dest_piece->size_bytes(), 0);
  }
}

MutableLiteralBase::~MutableLiteralBase() {}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = literal.shape_.Clone();
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}

MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  shape_ = literal.shape_.Clone();
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);

  return *this;
}

MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
    : MutableLiteralBase() {
  shape_ = literal->shape_.Clone();
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    MutableBorrowingLiteral literal, const ShapeIndex& view_root)
    : MutableLiteralBase() {
  shape_ = std::make_unique<Shape>(literal.piece(view_root).subshape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = std::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  CHECK(!shape_->IsTuple());

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
}
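
// Usage sketch for the single-buffer constructor above (illustrative only;
// `storage` is a hypothetical caller-owned buffer that must outlive the
// view):
//   std::vector<float> storage(4, 0.0f);
//   MutableBorrowingLiteral view(
//       reinterpret_cast<const char*>(storage.data()),
//       ShapeUtil::MakeShape(F32, {4}));
//   view.Set<float>({0}, 1.0f);  // Writes through to storage[0].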

MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = std::make_unique<Shape>(shape);
  if (!shape_->IsTuple()) {
    CHECK_EQ(src_buf_ptrs.size(), 1);
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());
    root_piece_->set_buffer(const_cast<char*>(src_buf_ptrs[0]));
  } else {
    CHECK(!ShapeUtil::IsNestedTuple(*shape_));
    CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());

    for (int i = 0; i < src_buf_ptrs.size(); ++i) {
      Piece child_piece;
      const auto& src_shape = shape_->tuple_shapes(i);
      CHECK(src_shape.IsArray());
      child_piece.set_subshape(&src_shape);
      child_piece.set_buffer(src_buf_ptrs[i]);
      root_piece_->emplace_back(std::move(child_piece));
    }
  }
}

MutableBorrowingLiteral::~MutableBorrowingLiteral() {
  if (root_piece_ != nullptr) {
    delete root_piece_;
  }
}

LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}

LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}

void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
  CHECK(shape.IsTuple());
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);

    auto child_piece = Piece();
    child_piece.set_subshape(&subshape);

    if (subshape.IsTuple()) {
      BuildPieceSubtree(subshape, &child_piece);
    }

    piece->emplace_back(std::move(child_piece));
  }
}

BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
    : LiteralBase(), shape_(std::make_unique<Shape>(shape)) {
  CHECK(shape_->IsArray());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
}

BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                                   const Shape& shape)
    : LiteralBase(), shape_(std::make_unique<Shape>(shape)) {
  CHECK(shape_->IsTuple());
  CHECK(!ShapeUtil::IsNestedTuple(*shape_));
  CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  BuildPieceSubtree(*shape_, &root_piece_);

  for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
    const auto& src_shape = shape_->tuple_shapes(i);
    CHECK(src_shape.IsArray());
    root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
  }
}
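
// Usage sketch for the tuple-of-buffers constructor above (illustrative only;
// `f32_buf` and `s32_buf` are hypothetical caller-owned buffers laid out to
// match the element shapes):
//   const char* bufs[] = {f32_buf, s32_buf};
//   BorrowingLiteral view(
//       bufs, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2}),
//                                        ShapeUtil::MakeShape(S32, {2})}));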

}  // namespace xla