/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using absl::StrCat;

constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;

// Converts between little and big endian.
//
// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
void ConvertEndianShort(string* bytes) {
  CHECK_EQ(bytes->size() % 2, 0);
  for (int64 i = 0; i < bytes->size(); i += 2) {
    std::swap((*bytes)[i], (*bytes)[i + 1]);
  }
}

void ConvertEndianShort(char* bytes, int64 size) {
  CHECK_EQ(size % 2, 0);
  for (int64 i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}

// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  return val;
}
uint16 GetRawValue(Eigen::half val) { return val.x; }

}  // namespace

LiteralBase::~LiteralBase() {}

std::ostream& operator<<(std::ostream& out, const Literal& literal) {
  out << literal.ToString();
  return out;
}

MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}
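
// Worked example (hypothetical shapes, for illustration only): when copying a
// {4, 3} block from a source with layout {1, 0} (dimension 1 most minor) into
// a destination with layout {0, 1} (dimension 0 most minor), we compare
// dimensions[1] = 3 against dimensions[0] = 4. Since 3 < 4, the destination's
// minor dimension 0 is selected, minor_loop_size becomes 4, and only
// source_stride is set here: the destination is then walked contiguously
// along its minor dimension while the source is walked with an explicit
// stride.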

Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}

void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    if (allocate_arrays) {
      if (LayoutUtil::IsSparseArray(shape)) {
        // For sparse arrays, the buffer must be of the size of the maximum
        // number of sparse elements possible.
        const int64 max_sparse_elements =
            LayoutUtil::MaxSparseElements(shape.layout());
        piece->set_buffer(
            new char[max_sparse_elements *
                     ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
        piece->set_sparse_indices(
            new SparseIndexArray(max_sparse_elements, shape.rank()));
      } else {
        piece->set_buffer(new char[piece->size_bytes()]);
      }
    }
  } else {
    // If the shape is neither an array nor tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}

Literal::Literal(const Shape& shape, bool allocate_arrays)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  CHECK(&root_piece_->subshape() == shape_.get());

  SetPiece(*shape_, root_piece_, allocate_arrays);
}

Literal::~Literal() {
  if (root_piece_ != nullptr) {
    DeallocateBuffers();
    delete root_piece_;
  }
}

void Literal::DeallocateBuffers() {
  root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->buffer() != nullptr) {
          delete[] piece->buffer();
          delete piece->sparse_indices();
        }
      });
}

Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}

Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_->subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  DCHECK(&root_piece_->subshape() == shape_.get());

  return *this;
}

Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}

const SparseIndexArray* LiteralBase::sparse_indices(
    const ShapeIndex& shape_index) const {
  return piece(shape_index).sparse_indices();
}

SparseIndexArray* MutableLiteralBase::sparse_indices(
    const ShapeIndex& shape_index) {
  return piece(shape_index).sparse_indices();
}

template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64> src_base,
    absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
  TF_RET_CHECK(src_literal.shape().rank() == src_base.size());
  TF_RET_CHECK(shape().rank() == dest_base.size());

  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If either of the two shapes is a scalar, we can just call StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform the copy only if neither src nor dest has a zero-element
    // dimension; otherwise it's a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing the source
    // index by one (walking through the minor dimension), and the destination
    // index by the proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64 src_index = linear_index(src_literal.shape(), src_indexes);
      int64 dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}

Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
                                           absl::Span<const int64> src_index,
                                           absl::Span<const int64> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
      src_literal.shape(), src_index);
  const int64 dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return Status::OK();
}

/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) {
    return InvalidArgument("Literal shape cannot include OPAQUE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        const LiteralProto* proto_element = &proto;
        for (int64 i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return Status::OK();
        }
        if (piece->subshape().element_type() == TOKEN) {
          return Status::OK();
        }

        CHECK(piece->subshape().IsArray());
        TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));

        return Status::OK();
      }));

  return std::move(literal);
}

std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          ShapeIndex src_index = {i};
          for (int64 j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer and sparse indices over to the element
          // Literal.
          dest_piece->set_buffer(src_piece.buffer());
          src_piece.set_buffer(nullptr);
          dest_piece->set_sparse_indices(src_piece.sparse_indices());
          src_piece.set_sparse_indices(nullptr);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
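
// Usage sketch (illustrative only; assumes LiteralUtil::MakeTupleOwned and
// LiteralUtil::CreateR0/CreateR1 from
// tensorflow/compiler/xla/literal_util.h):
//
//   Literal tuple = LiteralUtil::MakeTupleOwned(
//       LiteralUtil::CreateR0<float>(1.0f),
//       LiteralUtil::CreateR1<int32>({2, 3}));
//   std::vector<Literal> parts = tuple.DecomposeTuple();
//   // parts[0] is an F32[] literal, parts[1] an S32[2] literal, and `tuple`
//   // itself is left nil-shaped.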

namespace {

// Copies the elements in 'src' to 'dest'. The shapes and layouts of the data
// in the array slices are indicated by src_shape and dest_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}

}  // namespace

Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
    std::vector<int64> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                    \
  case (XLA_T):                                                           \
    CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                  subshape(), src.subshape());            \
    break;
      COPY_ELEMENTS(U8, uint8);
      COPY_ELEMENTS(U16, uint16);
      COPY_ELEMENTS(U32, uint32);
      COPY_ELEMENTS(U64, uint64);
      COPY_ELEMENTS(S8, int8);
      COPY_ELEMENTS(S16, int16);
      COPY_ELEMENTS(S32, int32);
      COPY_ELEMENTS(S64, int64);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  return Status::OK();
}

Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
    return InvalidArgument(
        "Destination subshape incompatible with source subshape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_subshape));
  }
  return root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return Status::OK();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return Status::OK();
        }
        // Construct the index of the corresponding piece in the source
        // literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
        return Status::OK();
      });
}
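
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR1 from
// tensorflow/compiler/xla/literal_util.h):
//
//   Literal src = LiteralUtil::CreateR1<float>({1, 2, 3});
//   Literal dest(ShapeUtil::MakeShape(F32, {3}));
//   TF_CHECK_OK(dest.CopyFrom(src));  // dest now holds {1, 2, 3}.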

Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_->ForEachSubpiece(
      [&](const ShapeIndex& src_index, const Piece& src_piece) {
        if (!src_piece.subshape().IsArray()) {
          return;
        }

        ShapeIndex dest_index = dest_shape_index;
        for (int64 i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        delete[] dest_piece.buffer();
        dest_piece.set_buffer(src_piece.buffer());
        delete dest_piece.sparse_indices();
        dest_piece.set_sparse_indices(src_piece.sparse_indices());
      });

  src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
  delete src_literal.root_piece_;
  src_literal.root_piece_ = new LiteralBase::Piece();
  src_literal.root_piece_->set_subshape(src_literal.shape_.get());

  return Status::OK();
}

Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64> src_base,
                                         absl::Span<const int64> dest_base,
                                         absl::Span<const int64> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
                                          copy_size);
    case U16:
      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
                                           copy_size);
    case U32:
      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
                                           copy_size);
    case U64:
      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
                                           copy_size);
    case S8:
      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
                                         copy_size);
    case S16:
      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
                                          copy_size);
    case S32:
      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
                                          copy_size);
    case S64:
      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
                                          copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
                                               copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}
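
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR2 from
// tensorflow/compiler/xla/literal_util.h):
//
//   Literal src = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}});
//   Literal dest = LiteralUtil::CreateR2<int32>(
//       {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}});
//   TF_CHECK_OK(dest.CopySliceFrom(src, /*src_base=*/{0, 0},
//                                  /*dest_base=*/{1, 1},
//                                  /*copy_size=*/{2, 2}));
//   // dest is now {{0, 0, 0}, {0, 1, 2}, {0, 3, 4}}.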

void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(element_count(), values.bits());
  CHECK_EQ(shape().element_type(), PRED);
  for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
    Set({i}, values.get(i));
  }
}

Literal LiteralBase::Relayout(const Layout& new_layout,
                              const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}
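
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR2 and
// LayoutUtil::MakeLayout helpers):
//
//   Literal row_major = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
//   Literal col_major = row_major.Relayout(LayoutUtil::MakeLayout({0, 1}));
//   // Same logical values; the physical buffer is now column-major.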

Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}

StatusOr<Literal> LiteralBase::Broadcast(
    const Shape& result_shape, absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Broadcast only supports arrays.");
  }

  for (int64 i = 0; i < dimensions.size(); i++) {
    TF_RET_CHECK(shape().dimensions(i) ==
                 result_shape.dimensions(dimensions[i]));
  }

  Literal result(result_shape);

  // scratch_source_index is temporary storage space for the computed index
  // into the input literal. We put it here to avoid allocating an std::vector
  // in every iteration of ShapeUtil::ForEachIndex.
  std::vector<int64> scratch_source_index(shape().dimensions_size());

  char* dest_data = static_cast<char*>(result.untyped_data());
  const char* source_data = static_cast<const char*>(untyped_data());
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  ShapeUtil::ForEachIndex(
      result_shape, [&](absl::Span<const int64> output_index) {
        for (int64 i = 0; i < dimensions.size(); ++i) {
          scratch_source_index[i] = output_index[dimensions[i]];
        }
        int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            result_shape, output_index);
        int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            shape(), scratch_source_index);
        memcpy(dest_data + primitive_size * dest_index,
               source_data + primitive_size * source_index, primitive_size);
        return true;
      });

  return std::move(result);
}
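
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR1):
//
//   Literal v = LiteralUtil::CreateR1<float>({1, 2});
//   StatusOr<Literal> b =
//       v.Broadcast(ShapeUtil::MakeShape(F32, {2, 3}), /*dimensions=*/{0});
//   // b holds {{1, 1, 1}, {2, 2, 2}}: dimension 0 of `v` maps to dimension 0
//   // of the result, and the new dimension 1 is filled by replication.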

StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64 elements_before = ShapeUtil::ElementsIn(shape());
  int64 elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
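
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR2):
//
//   Literal m = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
//   StatusOr<Literal> r = m.Reshape({4});
//   // r holds {1, 2, 3, 4}; the row-major element order is preserved.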

Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(IsPermutation(permutation, shape().rank()))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  Shape permuted_shape =
      ShapeUtil::PermuteDimensions(inverse_permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it keeps the dimension that was originally most
  // minor (the one of size 8) as the most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor-to-major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor-to-major
  // vector of the affine layout.
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}
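
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR2):
//
//   Literal m = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
//   Literal t = m.Transpose({1, 0});
//   // t is {{1, 4}, {2, 5}, {3, 6}}; only the shape and layout metadata
//   // change, the flat buffer is copied verbatim.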

template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  result_literal.EachCell<NativeT>(
      [&](absl::Span<const int64> indices, NativeT /*value*/) {
        for (int64 i = 0; i < result_shape.rank(); ++i) {
          new_indices[i] = indices[i] + start_indices[i];
        }
        NativeT value = Get<NativeT>(new_indices);
        result_literal.Set<NativeT>(indices, value);
      });
  return result_literal;
}

Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
                           absl::Span<const int64> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64 dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  const auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}
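
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR2):
//
//   Literal m = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}});
//   Literal s = m.Slice(/*start_indices=*/{0, 1}, /*limit_indices=*/{2, 3});
//   // s is {{2, 3}, {5, 6}}.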

Literal LiteralBase::Clone() const {
  Literal result(shape());
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return StrCat(static_cast<float>(Get<half>(multi_index, shape_index)));
    case F32:
      return StrCat(Get<float>(multi_index, shape_index));
    case BF16:
      return StrCat(
          static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
    case F64:
      return StrCat(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}

string LiteralBase::GetSparseElementAsString(
    int64 sparse_element_number, const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsSparseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return GetSparseElement<bool>(sparse_element_number, shape_index)
                 ? "true"
                 : "false";
    case S8:
      return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
    case S16:
      return StrCat(
          GetSparseElement<int16>(sparse_element_number, shape_index));
    case S32:
      return StrCat(
          GetSparseElement<int32>(sparse_element_number, shape_index));
    case S64:
      return StrCat(
          GetSparseElement<int64>(sparse_element_number, shape_index));
    case U8:
      return StrCat(
          GetSparseElement<uint8>(sparse_element_number, shape_index));
    case U16:
      return StrCat(
          GetSparseElement<uint16>(sparse_element_number, shape_index));
    case U32:
      return StrCat(
          GetSparseElement<uint32>(sparse_element_number, shape_index));
    case U64:
      return StrCat(
          GetSparseElement<uint64>(sparse_element_number, shape_index));
    case F16:
      return StrCat(static_cast<float>(
          GetSparseElement<half>(sparse_element_number, shape_index)));
    case F32:
      return StrCat(
          GetSparseElement<float>(sparse_element_number, shape_index));
    case BF16:
      return StrCat(static_cast<float>(
          GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
    case F64:
      return StrCat(
          GetSparseElement<double>(sparse_element_number, shape_index));
    case C64: {
      complex64 c =
          GetSparseElement<complex64>(sparse_element_number, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    case C128: {
      complex128 c =
          GetSparseElement<complex128>(sparse_element_number, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    default:
      LOG(FATAL) << "Invalid element type for sparse arrays: "
                 << PrimitiveType_Name(subshape.element_type());
  }
}

StatusOr<int64> LiteralBase::GetIntegralAsS64(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case U8:
      return Get<uint8>(multi_index);
    case S32:
      return Get<int32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    default:
      return FailedPrecondition("Array element type is not integral: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
}
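
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR1):
//
//   Literal p = LiteralUtil::CreateR1<bool>({true, false});
//   StatusOr<int64> v = p.GetIntegralAsS64({0});  // v holds 1.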

size_t LiteralBase::Hash() const {
  using tensorflow::Hash64;
  using tensorflow::Hash64Combine;

  size_t hash_value = ShapeUtil::Hash(shape());

  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }

        CHECK(LayoutUtil::IsDense(subshape.layout()));
        hash_value = Hash64Combine(
            hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
                               size_bytes(index)));
      });

  return hash_value;
}

Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
                                            int64 value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      Set<bool>(multi_index, value);
      break;
    case U8:
      Set<uint8>(multi_index, value);
      break;
    case S32:
      Set<int32>(multi_index, value);
      break;
    case S64:
      Set<int64>(multi_index, value);
      break;
    case U32:
      Set<uint32>(multi_index, value);
      break;
    case U64:
      Set<uint64>(multi_index, value);
      break;
    default:
      return FailedPrecondition("Array element type is not integral: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}

absl::Span<const int64> LiteralBase::GetSparseIndex(
    int64 sparse_element_number, const ShapeIndex& shape_index) const {
  const Piece& p = piece(shape_index);
  CHECK_GE(sparse_element_number, 0);
  CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
  return p.sparse_indices()->At(sparse_element_number);
}

void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
  piece(shape_index).SortSparseElements();
}

void LiteralBase::Piece::SortSparseElements() {
  switch (subshape().element_type()) {
    case PRED:
      SortSparseElementsInternal<bool>();
      break;
    case S8:
      SortSparseElementsInternal<int8>();
      break;
    case U8:
      SortSparseElementsInternal<uint8>();
      break;
    case S16:
      SortSparseElementsInternal<int16>();
      break;
    case U16:
      SortSparseElementsInternal<uint16>();
      break;
    case S32:
      SortSparseElementsInternal<int32>();
      break;
    case U32:
      SortSparseElementsInternal<uint32>();
      break;
    case S64:
      SortSparseElementsInternal<int64>();
      break;
    case U64:
      SortSparseElementsInternal<uint64>();
      break;
    case F32:
      SortSparseElementsInternal<float>();
      break;
    case F64:
      SortSparseElementsInternal<double>();
      break;
    case C64:
      SortSparseElementsInternal<complex64>();
      break;
    case C128:
      SortSparseElementsInternal<complex128>();
      break;
    case F16:
      SortSparseElementsInternal<half>();
      break;
    case BF16:
      SortSparseElementsInternal<bfloat16>();
      break;
    default:
      LOG(FATAL) << "Element type not valid for sparse array: "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

template <typename NativeT>
void LiteralBase::Piece::SortSparseElementsInternal() {
  CHECK(LayoutUtil::IsSparseArray(subshape()));
  int64 num_elements = sparse_indices()->index_count();
  auto values = data<NativeT>();
  CHECK_LE(num_elements, values.size());
  sparse_indices()->SortWithValues(
      absl::Span<NativeT>(values.data(), num_elements));
}

namespace {

string ShapeToString(bool print_layout, const Shape& shape) {
  return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
                      : ShapeUtil::HumanString(shape);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces);

void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<string> tuple_pieces;
  for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}

void SparseArrayToStringHelper(const LiteralBase& literal,
                               const Shape& subshape, bool print_shape,
                               bool print_layout, std::vector<string>* pieces) {
  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
  }
  pieces->push_back("{");
  int64 rank = subshape.rank();
  int64 num_elements = literal.sparse_element_count();
  for (int64 i = 0; i < num_elements; ++i) {
    if (i > 0) {
      pieces->push_back(", ");
    }
    if (rank == 1) {
      pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
      pieces->push_back(": ");
    } else {
      pieces->push_back("[");
      pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", "));
      pieces->push_back("]: ");
    }
    pieces->push_back(literal.GetSparseElementAsString(i));
  }
  pieces->push_back("}");
}

void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64 rank = subshape.rank();

  std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
      to_string_recursive = [&](absl::Span<const int64> dimensions,
                                std::vector<int64>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        auto brace_to_string = [&](string brace) -> string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices->size() < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            std::vector<int64> cloned_indices(*accum_indices);
            cloned_indices.push_back(i);
            to_string_recursive(dimensions.subspan(1), &cloned_indices);
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    pieces->push_back(" ");
  }
  std::vector<int64> indices = {};
  std::vector<int64> dimensions(subshape.dimensions().begin(),
                                subshape.dimensions().end());
  to_string_recursive(dimensions, &indices);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else if (LayoutUtil::IsSparseArray(subshape)) {
    SparseArrayToStringHelper(literal, subshape, print_shape, print_layout,
                              pieces);
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                             pieces);
  }
}

}  // namespace

int64 LiteralBase::sparse_element_count() const {
  CHECK(LayoutUtil::IsSparseArray(shape()));
  return sparse_indices()->index_count();
}

string LiteralBase::ToString() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithoutShape() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/false,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithLayout() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/true, &pieces);
  return absl::StrJoin(pieces, "");
}

void LiteralBase::EachCellAsString(
    const std::function<void(absl::Span<const int64> indices,
                             const string& value)>& per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
      shape(), /*linear_index=*/0);
  do {
    per_cell(indices, GetAsString(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}

namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
                                               const ConverterType& converter) {
  CHECK(src_literal.shape().IsArray());
  Literal result_literal(ShapeUtil::ChangeElementType(
      src_literal.shape(),
      primitive_util::NativeToPrimitiveType<NativeDestT>()));
  auto src_data = src_literal.data<NativeSrcT>();
  auto dest_data = result_literal.template data<NativeDestT>();
  int64 num_elements = src_literal.element_count();

  for (int64 i = 0; i < num_elements; ++i) {
    dest_data[i] = converter(src_data[i]);
  }
  return result_literal;
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
                            (!std::is_same<NativeDestT, complex64>::value &&
                             !std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
                         !std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return absl::bit_cast<NativeDestT>(GetRawValue(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short and then use raw_uint16_to_half.
  auto converter = [](NativeSrcT src) {
    return Eigen::half_impl::raw_uint16_to_half(
        absl::bit_cast<uint16>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}

// This template specialization is here to make the compiler happy. bit_cast
// has a static check that the types are the same size. This specialization
// should never be used because the source and destination types are checked
// for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}

template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}

StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
      // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}

}  // namespace

StatusOr<Literal> LiteralBase::Convert(
    PrimitiveType primitive_dest_type) const {
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
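
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR1):
//
//   Literal f = LiteralUtil::CreateR1<float>({1.5f, 2.5f});
//   StatusOr<Literal> s = f.Convert(S32);
//   // s holds {1, 2}: each element is static_cast to the new type.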

StatusOr<Literal> LiteralBase::BitcastConvert(
    PrimitiveType primitive_dest_type) const {
  if (primitive_util::BitWidth(shape().element_type()) !=
      primitive_util::BitWidth(primitive_dest_type)) {
    return InvalidArgument(
        "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
        "%d",
        PrimitiveType_Name(shape().element_type()),
        PrimitiveType_Name(primitive_dest_type),
        primitive_util::BitWidth(shape().element_type()),
        primitive_util::BitWidth(primitive_dest_type));
  }
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
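
// Usage sketch (illustrative only; assumes LiteralUtil::CreateR0):
//
//   Literal f = LiteralUtil::CreateR0<float>(1.0f);
//   StatusOr<Literal> u = f.BitcastConvert(U32);
//   // u holds 0x3F800000, the IEEE-754 bit pattern of 1.0f; a request with
//   // mismatched bit widths (e.g. F32 -> U16) returns InvalidArgument.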

StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
  if (!dest_shape.IsTuple()) {
    return Convert(dest_shape.element_type());
  }
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    auto element = LiteralSlice(*this, {i});
    TF_ASSIGN_OR_RETURN(
        auto new_element,
        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
    elements.push_back(std::move(new_element));
  }
  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}

/* static */ Literal MutableLiteralBase::MoveIntoTuple(
    absl::Span<Literal> elements) {
  std::vector<Shape> element_shapes;
  for (const Literal& element : elements) {
    element_shapes.push_back(element.shape());
  }
  Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
                  /*allocate_arrays=*/false);
  for (int i = 0; i < elements.size(); ++i) {
    TF_CHECK_OK(
        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
  }
  return literal;
}
1459
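// Recursively enumerates every multi-index of the subshape; once the index is
// at full rank, the two pieces are compared element-by-element.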
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}

bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
  DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));

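  // Fast path: when the shapes are equal including layout and the array is
  // dense, the physical byte layouts match, so a single memcmp suffices.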
  if (ShapeUtil::Equal(subshape(), other.subshape()) &&
      LayoutUtil::IsDenseArray(subshape())) {
    CHECK_EQ(size_bytes(), other.size_bytes());
    return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
  }

  std::vector<int64> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8>(other, &multi_index);
    case S16:
      return EqualElementsInternal<int16>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64>(other, &multi_index);
    case U16:
      return EqualElementsInternal<uint16>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    case C128:
      return EqualElementsInternal<complex128>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

bool LiteralBase::operator==(const LiteralBase& other) const {
  if (!ShapeUtil::Compatible(shape(), other.shape())) {
    return false;
  }

  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        const Piece& other_piece = other.piece(index);
        if (!piece.EqualElements(other_piece)) {
          return false;
        }
        return true;
      });
}

namespace {

template <typename NativeT>
static bool AllElementsEqualValue(absl::Span<const NativeT> data,
                                  NativeT value) {
  for (int64 i = 0; i < data.size(); ++i) {
    if (data[i] != value) {
      return false;
    }
  }
  return true;
}

}  // namespace

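// Returns true iff every array element equals `value` after casting it to the
// literal's element type. A negative `value` can never match an unsigned
// element type, and only 0 or 1 can match PRED. Example (illustrative,
// assumes LiteralUtil from literal_util.h):
//   LiteralUtil::CreateR2<int32>({{7, 7}, {7, 7}}).IsAll(7);  // true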
bool LiteralBase::IsAll(int8 value) const {
  return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
                                                  const Piece& piece) {
    if (!piece.subshape().IsArray()) {
      return true;
    }

    auto piece_is_all = [&]() {
      switch (shape().element_type()) {
        case U8:
          if (value >= 0) {
            return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
          }
          return false;
        case U16:
          if (value >= 0) {
            return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
          }
          return false;
        case U32:
          if (value >= 0) {
            return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
          }
          return false;
        case U64:
          if (value >= 0) {
            return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
          }
          return false;
        case S8:
          return AllElementsEqualValue<int8>(piece.data<int8>(), value);
        case S16:
          return AllElementsEqualValue<int16>(piece.data<int16>(), value);
        case S32:
          return AllElementsEqualValue<int32>(piece.data<int32>(), value);
        case S64:
          return AllElementsEqualValue<int64>(piece.data<int64>(), value);
        case F32:
          return AllElementsEqualValue<float>(piece.data<float>(), value);
        case F64:
          return AllElementsEqualValue<double>(piece.data<double>(), value);
        case F16:
          return AllElementsEqualValue<half>(piece.data<half>(),
                                             static_cast<half>(value));
        case BF16:
          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
                                                 static_cast<bfloat16>(value));
        case PRED:
          if (value == 0) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), false);
          }
          if (value == 1) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), true);
          }
          return false;
        default:
          return false;
      }
      return false;
    };

    if (!piece_is_all()) {
      return false;
    }
    return true;
  });
}

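// Floating-point counterpart of IsAll: compares every element against `value`
// cast to the literal's floating-point element type; literals with
// non-floating-point element types return false.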
bool LiteralBase::IsAllFloat(float value) const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        switch (shape().element_type()) {
          case F32:
            return AllElementsEqualValue<float>(piece.data<float>(), value);
          case F64:
            return AllElementsEqualValue<double>(piece.data<double>(), value);
          case F16:
            return AllElementsEqualValue<half>(piece.data<half>(),
                                               static_cast<half>(value));
          case BF16:
            return AllElementsEqualValue<bfloat16>(
                piece.data<bfloat16>(), static_cast<bfloat16>(value));
          default:
            return false;
        }
      });
}

bool LiteralBase::IsAllComplex(complex64 value) const {
  switch (shape().element_type()) {
    case C64:
      return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
                                              value);
    case C128:
      return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
                                               value);
    default:
      return false;
  }
}

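// Returns true iff every array element equals that array's own first element;
// zero-element arrays return false because they have no first element.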
bool LiteralBase::IsAllFirst() const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // Empty shapes are not all the first element since there is no first
        // element.
        if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
          return false;
        }
        auto piece_is_all = [&]() {
          switch (piece.subshape().element_type()) {
            case PRED: {
              auto data = piece.data<bool>();
              return AllElementsEqualValue<bool>(data, data[0]);
            }
            // 8 bit types
            case S8: {
              auto data = piece.data<int8>();
              return AllElementsEqualValue<int8>(data, data[0]);
            }
            case U8: {
              auto data = piece.data<uint8>();
              return AllElementsEqualValue<uint8>(data, data[0]);
            }
            // 16 bit types
            case BF16: {
              auto data = piece.data<bfloat16>();
              return AllElementsEqualValue<bfloat16>(data, data[0]);
            }
            case F16: {
              auto data = piece.data<half>();
              return AllElementsEqualValue<half>(data, data[0]);
            }
            case S16: {
              auto data = piece.data<int16>();
              return AllElementsEqualValue<int16>(data, data[0]);
            }
            case U16: {
              auto data = piece.data<uint16>();
              return AllElementsEqualValue<uint16>(data, data[0]);
            }
            // 32 bit types
            case F32: {
              auto data = piece.data<float>();
              return AllElementsEqualValue<float>(data, data[0]);
            }
            case U32: {
              auto data = piece.data<uint32>();
              return AllElementsEqualValue<uint32>(data, data[0]);
            }
            case S32: {
              auto data = piece.data<int32>();
              return AllElementsEqualValue<int32>(data, data[0]);
            }
            // 64 bit types
            case C64: {
              auto data = piece.data<complex64>();
              return AllElementsEqualValue<complex64>(data, data[0]);
            }
            case F64: {
              auto data = piece.data<double>();
              return AllElementsEqualValue<double>(data, data[0]);
            }
            case S64: {
              auto data = piece.data<int64>();
              return AllElementsEqualValue<int64>(data, data[0]);
            }
            case U64: {
              auto data = piece.data<uint64>();
              return AllElementsEqualValue<uint64>(data, data[0]);
            }

            case C128: {
              auto data = piece.data<complex128>();
              return AllElementsEqualValue<complex128>(data, data[0]);
            }
            default:
              return false;
          }
        };

        if (!piece_is_all()) {
          return false;
        }
        return true;
      });
}

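// Returns true iff this is a rank-1 array whose i-th element equals i, e.g.
// (illustrative) LiteralUtil::CreateR1<int32>({0, 1, 2, 3}).IsR1Iota().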
bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  auto is_iota_at_idx = [&](const int64 idx) {
    switch (shape().element_type()) {
      case U8:
        return Get<uint8>({idx}) == idx;
      case U16:
        return Get<uint16>({idx}) == idx;
      case U32:
        return Get<uint32>({idx}) == idx;
      case U64:
        return Get<uint64>({idx}) == idx;
      case S8:
        return Get<int8>({idx}) == idx;
      case S16:
        return Get<int16>({idx}) == idx;
      case S32:
        return Get<int32>({idx}) == idx;
      case S64:
        return Get<int64>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      case PRED:
        return Get<bool>({idx}) == idx;
      // token, opaque, tuple, etc. are all not iota.
      default:
        return false;
    }
  };

  const int64 elements = ShapeUtil::ElementsIn(shape());
  for (int64 idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}

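// Returns true iff the element at `indices` is zero in the literal's element
// type; complex elements compare against 0 + 0i, PRED against false.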
bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8>(indices) == 0;
    case U16:
      return Get<uint16>(indices) == 0;
    case U32:
      return Get<uint32>(indices) == 0;
    case U64:
      return Get<uint64>(indices) == 0;
    case S8:
      return Get<int8>(indices) == 0;
    case S16:
      return Get<int16>(indices) == 0;
    case S32:
      return Get<int32>(indices) == 0;
    case S64:
      return Get<int64>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "Unsupported element type: "
                 << PrimitiveType_Name(shape().element_type());
  }
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}

}  // namespace

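// Serializes this piece into `proto`. Most element types map to a repeated
// proto field; 16-bit element types (U16, S16, F16, BF16) are packed into a
// byte string stored little-endian, so big-endian hosts byte-swap on the way
// out.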
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case U16:
      *proto->mutable_u16s() =
          string(reinterpret_cast<const char*>(data<uint16_t>().data()),
                 size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() =
          string(reinterpret_cast<const char*>(data<int16_t>().data()),
                 size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() =
          string(reinterpret_cast<const char*>(data<bfloat16>().data()),
                 size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape, which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

const void* LiteralBase::Piece::untyped_data() const {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

void* LiteralBase::Piece::untyped_data() {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
                             const RepeatedFieldT& src) {
  if (dest.size() != src.size()) {
    return InvalidArgument(
        "Expected %lu elements in LiteralProto repeated field, has %d",
        dest.size(), src.size());
  }
  std::copy(src.begin(), src.end(), dest.begin());
  return Status::OK();
}

}  // namespace

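// Deserializes `proto` into this piece. The proto's shape must already have
// been validated against the piece's subshape (see
// MutableLiteralBase::CreateFromProto); 16-bit byte-string fields are
// byte-swapped back on big-endian hosts.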
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  if (LayoutUtil::IsSparseArray(subshape())) {
    // Compute the number of elements (indices) in the sparse shape and reserve
    // the necessary space in sparse_indices.
    TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse";
    TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0)
        << "Unexpected number of indices in proto ("
        << proto.sparse_indices_size() << ") for shape of rank "
        << subshape().rank();
    const int64 index_count = proto.sparse_indices_size() / subshape().rank();
    sparse_indices()->Resize(index_count);

    // Copy the indices from the proto into the SparseIndexArray object.
    TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(),
                                             proto.sparse_indices()));
  }

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      auto s8_data = data<int8>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    case S16: {
      const string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;

    case BF16: {
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Unsupported element type in shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return Status::OK();
}

LiteralProto LiteralBase::ToProto() const {
  LiteralProto proto;
  root_piece().ForEachSubpiece(
      [&](const ShapeIndex& index, const Piece& piece) {
        LiteralProto* proto_piece = &proto;
        for (int64 i : index) {
          while (proto_piece->tuple_literals_size() <= i) {
            proto_piece->add_tuple_literals();
          }
          proto_piece = proto_piece->mutable_tuple_literals(i);
        }
        piece.WriteToProto(proto_piece);
      });

  if (LayoutUtil::IsSparseArray(shape())) {
    CopyToRepeatedField(proto.mutable_sparse_indices(),
                        sparse_indices()->data());
  }

  return proto;
}
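
// Round-trip sketch (illustrative; MutableLiteralBase::CreateFromProto is the
// inverse referenced in CopyFromProto above):
//   LiteralProto proto = literal.ToProto();
//   StatusOr<Literal> copy = MutableLiteralBase::CreateFromProto(proto);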

const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
  return piece(shape_index).untyped_data();
}

void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
  return piece(shape_index).untyped_data();
}

int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
  return piece(shape_index).size_bytes();
}

string LiteralBase::GetR1U8AsString() const {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(shape().element_type(), U8);
  return string(absl::bit_cast<const char*>(data<uint8>().data()),
                ShapeUtil::ElementsIn(shape()));
}

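// Recursively mirrors the piece tree of `src_piece` into `dest_piece`,
// sharing (not copying) the underlying array buffers.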
void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
                                               Piece* src_piece,
                                               Piece* dest_piece) {
  DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
      << "src_piece has shape: "
      << ShapeUtil::HumanString(src_piece->subshape())
      << "; dest_piece has shape: "
      << ShapeUtil::HumanString(dest_piece->subshape());
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);

      dest_piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    dest_piece->set_buffer(src_piece->buffer());
  } else {
    // If the shape is neither an array nor a tuple, then it must be
    // zero-sized; otherwise some memory would need to be allocated for it.
    CHECK_EQ(dest_piece->size_bytes(), 0);
  }
}

MutableLiteralBase::~MutableLiteralBase() {}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}

MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);

  return *this;
}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableLiteralBase& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal->shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    MutableBorrowingLiteral literal, const ShapeIndex& view_root)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  CHECK(!shape_->IsTuple());

  root_piece_ = new Piece();
  root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_->set_subshape(shape_.get());
}

MutableBorrowingLiteral::~MutableBorrowingLiteral() {
  if (root_piece_ != nullptr) {
    root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* piece) {
          if (piece->buffer() != nullptr) {
            delete piece->sparse_indices();
          }
        });
    delete root_piece_;
  }
}

LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}

LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}

void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
  CHECK(shape.IsTuple());
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);

    auto child_piece = Piece();
    child_piece.set_subshape(&subshape);

    if (subshape.IsTuple()) {
      BuildPieceSubtree(subshape, &child_piece);
    }

    piece->emplace_back(std::move(child_piece));
  }
}

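// BorrowingLiteral is a non-owning view over caller-owned memory; the buffer
// must remain live for the lifetime of the view. Sketch (illustrative):
//   std::vector<float> buf = {1.0f, 2.0f};
//   BorrowingLiteral view(reinterpret_cast<const char*>(buf.data()),
//                         ShapeUtil::MakeShape(F32, {2}));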
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsArray());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = Piece();
  root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_.set_subshape(shape_.get());
}

BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                                   const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsTuple());
  CHECK(!ShapeUtil::IsNestedTuple(*shape_));
  CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  BuildPieceSubtree(*shape_, &root_piece_);

  for (int i = 0; i < src_buf_ptrs.size(); ++i) {
    const auto& src_shape = shape_->tuple_shapes(i);
    CHECK(src_shape.IsArray());
    root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
  }
}

}  // namespace xla