/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using absl::StrCat;

constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Literals can be used as DMA targets, which can require alignment. We
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
// alignment.
constexpr int kMinimumAlignment = 64;

// Converts between little and big endian.
//
// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
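// For example, the byte sequence {0x12, 0x34, 0xAB, 0xCD} becomes
// {0x34, 0x12, 0xCD, 0xAB}: each 16-bit element has its two bytes swapped.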
void ConvertEndianShort(string* bytes) {
  CHECK_EQ(bytes->size() % 2, 0);
  for (int64 i = 0, end = bytes->size(); i < end; i += 2) {
    std::swap((*bytes)[i], (*bytes)[i + 1]);
  }
}

void ConvertEndianShort(char* bytes, int64 size) {
  CHECK_EQ(size % 2, 0);
  for (int64 i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}

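// Flattens multi-line output onto one line, e.g.
// CompactOneline("{ 1, 2,\n  3 }") returns "{ 1, 2, 3 }".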
string CompactOneline(const string& input) {
  string result;
  std::vector<string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
  bool first = true;
  // Concatenate elements in "v" with spaces separating them, but ignoring
  // empty entries.
  for (const auto& s : v) {
    if (s.empty()) {
      continue;
    }
    absl::StrAppend(&result, (first ? "" : " "), s);
    first = false;
  }
  return result;
}

// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  return val;
}
uint16 GetRawValue(Eigen::half val) { return val.x; }

bool LiteralProtoHasValues(const LiteralProto& proto) {
  return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
         proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
         proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
         proto.c64s_size() || proto.c128s_size() ||
         proto.tuple_literals_size() || !proto.f16s().empty() ||
         !proto.bf16s().empty() || !proto.u16s().empty() ||
         !proto.s16s().empty();
}

}  // namespace

LiteralBase::~LiteralBase() {}

std::ostream& operator<<(std::ostream& out, const Literal& literal) {
  out << literal.ToString();
  return out;
}

MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
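    // E.g. (illustrative), with copy dimensions {4, 8}, a source whose minor
    // dimension is 0 (extent 4) and a dest whose minor dimension is 1
    // (extent 8): since 4 < 8, dimension 1 drives the inner strided loop,
    // source_stride becomes the source's stride along dimension 1, and
    // step[1] is set to 8.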
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}

Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}

void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    if (allocate_arrays) {
      piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
          piece->size_bytes(), kMinimumAlignment)));
      if (shape.is_dynamic()) {
        CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
        piece->set_dynamic_size_buffer(
            static_cast<int32*>(tensorflow::port::AlignedMalloc(
                piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
      }
    }
  } else {
    // If the shape is neither an array nor tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}

Literal::Literal(const Shape& shape, bool allocate_arrays)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  CHECK(&root_piece_->subshape() == shape_.get());

  SetPiece(*shape_, root_piece_, allocate_arrays);
}

Literal::~Literal() {
  if (root_piece_ != nullptr) {
    DeallocateBuffers();
    delete root_piece_;
  }
}

void Literal::DeallocateBuffers() {
  root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->buffer());
        }
        if (piece->dynamic_size_buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
        }
      });
}

Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}

Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_->subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  DCHECK(&root_piece_->subshape() == shape_.get());

  return *this;
}

Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}

int32 LiteralBase::GetDynamicSize(int64 dim_index) const {
  return GetDynamicSize(dim_index, {});
}

int32 LiteralBase::GetDynamicSize(int64 dim_index,
                                  const ShapeIndex& shape_index) const {
  return piece(shape_index).GetDynamicSize(dim_index);
}

absl::optional<int64> LiteralBase::GetFirstInteger() const {
  switch (shape().element_type()) {
    case U8:
      return GetFirstElement<uint8>();
    case U16:
      return GetFirstElement<uint16>();
    case U32:
      return GetFirstElement<uint32>();
    case U64: {
      int64 v = GetFirstElement<uint64>();
      if (v < 0) {
        return absl::nullopt;
      }
      return v;
    }
    case S8:
      return GetFirstElement<int8>();
    case S16:
      return GetFirstElement<int16>();
    case S32:
      return GetFirstElement<int32>();
    case S64:
      return GetFirstElement<int64>();
    default:
      return absl::nullopt;
  }
}

template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64> src_base,
    absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
  const int64 src_base_size = src_base.size();
  const int64 dest_base_size = dest_base.size();
  TF_RET_CHECK(src_literal.shape().rank() == src_base_size);
  TF_RET_CHECK(shape().rank() == dest_base_size);

  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If any of the two shapes are scalars, we can just call the StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform copy if neither src nor dest has dimensions with zero element,
    // otherwise it's a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64 src_index = linear_index(src_literal.shape(), src_indexes);
      int64 dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}

Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
                                           absl::Span<const int64> src_index,
                                           absl::Span<const int64> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
      src_literal.shape(), src_index);
  const int64 dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return Status::OK();
}

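// Usage sketch (illustrative; assumes a LiteralProto produced by
// LiteralBase::ToProto()):
//   TF_ASSIGN_OR_RETURN(Literal restored, Literal::CreateFromProto(proto));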
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto, bool prohibit_empty_literal) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
    return InvalidArgument(
        "Literal shape cannot include OPAQUE_TYPE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        const LiteralProto* proto_element = &proto;
        for (int64 i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return Status::OK();
        }
        if (piece->subshape().element_type() == TOKEN) {
          return Status::OK();
        }

        CHECK(piece->subshape().IsArray());

        // When prohibit_empty_literal is false (allowing literal with no
        // values), only copy from proto if the literal proto has values. This
        // mode is used for a learned cost model.
        if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
          TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
        }

        return Status::OK();
      }));

  return std::move(literal);
}

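// Usage sketch (illustrative; assumes the LiteralUtil factory helpers from
// literal_util.h):
//   Literal tuple = LiteralUtil::MakeTupleOwned(
//       LiteralUtil::CreateR0<float>(1.0f), LiteralUtil::CreateR0<int32>(2));
//   std::vector<Literal> parts = tuple.DecomposeTuple();
//   // parts[0] is f32[] 1, parts[1] is s32[] 2; the buffers are moved rather
//   // than copied, and `tuple` is left nil-shaped.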
std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          ShapeIndex src_index = {i};
          for (int64 j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->set_buffer(src_piece.buffer());
          dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
          src_piece.set_buffer(nullptr);
          src_piece.set_dynamic_size_buffer(nullptr);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}

namespace {

// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
}  // namespace

int32 LiteralBase::Piece::GetDynamicSize(int64 dim_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  if (!subshape_->is_dynamic_dimension(dim_index)) {
    // This is a static dimension, return size.
    return subshape_->dimensions(dim_index);
  }
  CHECK_NE(dynamic_size_buffer(), nullptr);
  return dynamic_size_buffer_[dim_index];
}

void LiteralBase::Piece::SetDynamicSize(int64 dim_index, int32 size) {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  CHECK(subshape_->is_dynamic_dimension(dim_index));
  if (dynamic_size_buffer() == nullptr) {
    // Lazily initialize the dynamic size buffer.
    set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
        dynamic_size_buffer_bytes(), kMinimumAlignment)));
    /*for (int64 i = 0; i < subshape().rank(); ++i) {
      // Initialized to -1 to help debug.
      dynamic_size_buffer_[i] = -1;
    }*/
  }
  dynamic_size_buffer_[dim_index] = size;
}

Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
                                    bool only_dynamic_bound) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    std::vector<int64> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                      \
  case (XLA_T):                                                             \
    if (only_dynamic_bound) {                                               \
      CopyElementsWithDynamicBound<NATIVE_T>(src);                          \
    } else {                                                                \
      CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                    subshape(), src.subshape());            \
    }                                                                       \
    break;
      COPY_ELEMENTS(U8, uint8);
      COPY_ELEMENTS(U16, uint16);
      COPY_ELEMENTS(U32, uint32);
      COPY_ELEMENTS(U64, uint64);
      COPY_ELEMENTS(S8, int8);
      COPY_ELEMENTS(S16, int16);
      COPY_ELEMENTS(S32, int32);
      COPY_ELEMENTS(S64, int64);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
  if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
    CHECK_NE(dynamic_size_buffer_, nullptr);
    CHECK_NE(src.dynamic_size_buffer_, nullptr);
    memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
           src.dynamic_size_buffer_bytes());
  }
  return Status::OK();
}

void MutableLiteralBase::SetDynamicSize(int64 dim_index, int32 size) {
  return SetDynamicSize(dim_index, {}, size);
}

void MutableLiteralBase::SetDynamicSize(int64 dim_index,
                                        const ShapeIndex& shape_index,
                                        int32 size) {
  Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
  CHECK_GE(subshape_->dimensions(dim_index), size);
  if (subshape_->dimensions(dim_index) == size) {
    subshape_->set_dynamic_dimension(dim_index, false);
    return;
  }
  subshape_->set_dynamic_dimension(dim_index, true);
  piece(shape_index).SetDynamicSize(dim_index, size);
}

Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index,
                                    bool only_dynamic_bound) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (only_dynamic_bound) {
    auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
    auto compact_shape =
        dest_subshape.is_static() ? dest_subshape : src_subshape;
    CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
        << compact_shape.ToString() << " vs " << bound_shape.ToString();
  } else {
    if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
      return InvalidArgument(
          "Destination subshape incompatible with source subshape: %s vs %s",
          ShapeUtil::HumanString(dest_subshape),
          ShapeUtil::HumanString(src_subshape));
    }
  }
  return root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return Status::OK();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return Status::OK();
        }
        // Construct the index of the corresponding piece in the source
        // literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64 i = dest_shape_index.size(), end = index.size(); i < end;
             ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(
            piece->CopyFrom(src_literal.piece(src_piece_index),
                            /*only_dynamic_bound=*/only_dynamic_bound));
        return Status::OK();
      });
}

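// Usage sketch (illustrative; assumes `dst` owns a tuple whose element 0 has
// the same shape as `src`):
//   TF_RETURN_IF_ERROR(dst.MoveFrom(std::move(src), /*dest_shape_index=*/{0}));
//   // dst's element 0 now owns src's old buffers; src is left nil-shaped.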
Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_->ForEachSubpiece(
      [&](const ShapeIndex& src_index, const Piece& src_piece) {
        if (!src_piece.subshape().IsArray()) {
          return;
        }

        ShapeIndex dest_index = dest_shape_index;
        for (int64 i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        tensorflow::port::AlignedFree(dest_piece.buffer());
        tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
        dest_piece.set_buffer(src_piece.buffer());
        dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
      });

  src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
  delete src_literal.root_piece_;
  src_literal.root_piece_ = new LiteralBase::Piece();
  src_literal.root_piece_->set_subshape(src_literal.shape_.get());

  return Status::OK();
}

Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64> src_base,
                                         absl::Span<const int64> dest_base,
                                         absl::Span<const int64> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
                                          copy_size);
    case U16:
      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
                                           copy_size);
    case U32:
      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
                                           copy_size);
    case U64:
      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
                                           copy_size);
    case S8:
      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
                                         copy_size);
    case S16:
      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
                                          copy_size);
    case S32:
      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
                                          copy_size);
    case S64:
      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
                                          copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
                                               copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}

void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(element_count(), values.bits());
  CHECK_EQ(shape().element_type(), PRED);
  for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
    Set({i}, values.get(i));
  }
}

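// Usage sketch (illustrative): relayouting a row-major f32[2,3] literal to
// column-major leaves the logical values unchanged but reorders the bytes:
//   Literal col_major = row_major.Relayout(LayoutUtil::MakeLayout({0, 1}));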
Literal LiteralBase::Relayout(const Layout& new_layout,
                              const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}

Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
  CHECK(bounded_shape.is_dynamic());
  Literal result(bounded_shape);
  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }
        for (int64 i = 0; i < subshape.rank(); ++i) {
          result.SetDynamicSize(i, subshape.dimensions(i));
        }
      });
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));

  return result;
}

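// Illustrative example: an f32[<=4] literal whose dynamic size is 2 becomes a
// plain f32[2] literal under ToStatic(); ToBoundedDynamic() above performs the
// inverse, re-padding the data up to the given bound.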
Literal LiteralBase::ToStatic() const {
  // Create a new shape where each dynamic dimension is frozen at its current
  // dynamic size.
  Shape new_shape = shape();
  ShapeUtil::ForEachMutableSubshape(
      &new_shape, [this](Shape* subshape, const ShapeIndex& index) {
        if (!subshape->IsArray()) {
          return;
        }
        for (int64 i = 0; i < subshape->rank(); ++i) {
          subshape->set_dynamic_dimension(i, false);
          subshape->set_dimensions(i, GetDynamicSize(i, index));
        }
      });
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
  return result;
}

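// Usage sketch (illustrative; assumes the LiteralUtil helpers from
// literal_util.h):
//   Literal row = LiteralUtil::CreateR1<float>({1, 2, 3});  // f32[3]
//   TF_ASSIGN_OR_RETURN(
//       Literal b, row.Broadcast(ShapeUtil::MakeShape(F32, {2, 3}), {1}));
//   // b is f32[2,3]: {{1, 2, 3}, {1, 2, 3}} -- source dimension i maps to
//   // result dimension dimensions[i].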
StatusOr<Literal> LiteralBase::Broadcast(
    const Shape& result_shape, absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Broadcast only supports arrays.");
  }

  for (int64 i = 0, end = dimensions.size(); i < end; i++) {
    TF_RET_CHECK(shape().dimensions(i) ==
                 result_shape.dimensions(dimensions[i]));
  }

  Literal result(result_shape);

  // scratch_source_index is temporary storage space for the computed index into
  // the input literal. We put it here to avoid allocating an std::vector in
  // every iteration of ShapeUtil::ForEachIndex.
  std::vector<int64> scratch_source_index(shape().dimensions_size());

  char* dest_data = static_cast<char*>(result.untyped_data());
  const char* source_data = static_cast<const char*>(untyped_data());
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  for (int64 i = 0; i < dimensions.size(); ++i) {
    int64 dynamic_size = GetDynamicSize(i);
    result.SetDynamicSize(dimensions[i], dynamic_size);
  }

  ShapeUtil::ForEachIndex(
      result_shape, [&](absl::Span<const int64> output_index) {
        for (int64 i = 0, end = dimensions.size(); i < end; ++i) {
          scratch_source_index[i] = output_index[dimensions[i]];
        }
        int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            result_shape, output_index);
        int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            shape(), scratch_source_index);
        memcpy(dest_data + primitive_size * dest_index,
               source_data + primitive_size * source_index, primitive_size);
        return true;
      });

  return std::move(result);
}

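// Usage sketch (illustrative):
//   Literal m = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});  // f32[2,2]
//   TF_ASSIGN_OR_RETURN(Literal r, m.Reshape({4}));
//   // r is f32[4]: {1, 2, 3, 4}; element order is preserved because the data
//   // is first brought into the default (row-major) layout.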
StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  if (shape().is_dynamic()) {
    return Unimplemented("Dynamic reshape is not implemented.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64 elements_before = ShapeUtil::ElementsIn(shape());
  int64 elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}

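// Usage sketch (illustrative):
//   Literal m = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
//   Literal t = m.Transpose({1, 0});  // f32[3,2]
//   // t.Get<float>({2, 0}) == 3; only the shape and layout change, the
//   // underlying bytes are copied verbatim.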
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
  // most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  for (int64 i = 0; i < shape().rank(); i++) {
    new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
  }
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}

template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  CHECK(result_literal
            .Populate<NativeT>([&](absl::Span<const int64> indices) {
              for (int64 i = 0; i < result_shape.rank(); ++i) {
                new_indices[i] = indices[i] + start_indices[i];
              }
              return Get<NativeT>(new_indices);
            })
            .ok());
  for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
    if (shape().is_dynamic_dimension(dnum)) {
      int64 dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
      CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
      dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
      result_literal.SetDynamicSize(dnum, dynamic_size);
    }
  }
  return result_literal;
}

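// Usage sketch (illustrative):
//   Literal m = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}});
//   Literal s = m.Slice({0, 1}, {2, 3});  // s32[2,2]: {{2, 3}, {5, 6}}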
Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
                           absl::Span<const int64> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64 dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}

Literal LiteralBase::Clone() const {
  Literal result(shape());
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return RoundTripFpToString(Get<half>(multi_index, shape_index));
    case F32:
      return RoundTripFpToString(Get<float>(multi_index, shape_index));
    case BF16:
      return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
    case F64:
      return RoundTripFpToString(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}

absl::optional<int64> LiteralBase::GetIntegralAsS64(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case S8:
      return Get<int8>(multi_index);
    case U8:
      return Get<uint8>(multi_index);
    case S16:
      return Get<int16>(multi_index);
    case U16:
      return Get<uint16>(multi_index);
    case S32:
      return Get<int32>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    default:
      return absl::nullopt;
  }
}

absl::optional<double> LiteralBase::GetAsDouble(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      return static_cast<double>(Get<half>(multi_index));
    case F32:
      return static_cast<double>(Get<float>(multi_index));
    case F64:
      return Get<double>(multi_index);
    case BF16:
      return static_cast<double>(Get<bfloat16>(multi_index));
    default:
      return absl::nullopt;
  }
}

absl::optional<complex128> LiteralBase::GetAsComplex128(
    absl::Span<const int64> multi_index) const {
  switch (shape().element_type()) {
    case BF16:
      return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
    case F16:
      return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
    case F32:
      return {{Get<float>(multi_index), 0}};
    case F64:
      return {{Get<double>(multi_index), 0}};
    case C64:
      return {Get<complex64>(multi_index)};
    case C128:
      return {Get<complex128>(multi_index)};
    case S8:
      return {Get<int8>(multi_index)};
    default:
      return absl::nullopt;
  }
}

size_t LiteralBase::Hash() const {
  using tensorflow::Hash64;
  using tensorflow::Hash64Combine;

  size_t hash_value = ShapeUtil::Hash(shape());

  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }

        CHECK(LayoutUtil::IsDense(subshape.layout()));
        hash_value = Hash64Combine(
            hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
                               size_bytes(index)));
      });

  return hash_value;
}

Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
                                            int64 value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      Set<bool>(multi_index, value);
      break;
    case U8:
      Set<uint8>(multi_index, value);
      break;
    case S32:
      Set<int32>(multi_index, value);
      break;
    case S64:
      Set<int64>(multi_index, value);
      break;
    case U32:
      Set<uint32>(multi_index, value);
      break;
    case U64:
      Set<uint64>(multi_index, value);
      break;
    default:
      return FailedPrecondition("Array element type is not integral: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}

Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
                                         double value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      Set<half>(multi_index, Eigen::half(value));
      break;
    case F32:
      Set<float>(multi_index, value);
      break;
    case F64:
      Set<double>(multi_index, value);
      break;
    case BF16:
      Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
      break;
    default:
      return FailedPrecondition("Array element type is not floating: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}

namespace {

string ShapeToString(bool print_layout, const Shape& shape) {
  return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
                      : ShapeUtil::HumanString(shape);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces);

void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<string> tuple_pieces;
  for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}

void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64 rank = subshape.rank();

  std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
      to_string_recursive = [&](absl::Span<const int64> dimensions,
                                std::vector<int64>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        auto brace_to_string = [&](string brace) -> string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            const int64 accum_indices_size = accum_indices->size();
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices_size < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            std::vector<int64> cloned_indices(*accum_indices);
            cloned_indices.push_back(i);
            to_string_recursive(dimensions.subspan(1), &cloned_indices);
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    if (subshape.is_dynamic()) {
      pieces->push_back("(");
      for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
        pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
        if (i < subshape.dimensions_size() - 1) {
          pieces->push_back(",");
        }
      }
      pieces->push_back(")");
    }
    pieces->push_back(" ");
  }
  std::vector<int64> indices = {};
  std::vector<int64> dimensions;
  dimensions.reserve(subshape.rank());
  for (int64 i = 0; i < subshape.rank(); ++i) {
    dimensions.push_back(literal.GetDynamicSize(i, shape_index));
  }
  to_string_recursive(dimensions, &indices);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                             pieces);
  }
}

}  // namespace

string LiteralBase::ToString() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringOneline() const {
  return CompactOneline(ToString());
}

string LiteralBase::ToStringWithoutShape() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/false,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithoutShapeOneline() const {
  return CompactOneline(ToStringWithoutShape());
}

string LiteralBase::ToStringWithLayout() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/true, &pieces);
  return absl::StrJoin(pieces, "");
}

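// Usage sketch (illustrative):
//   literal.EachCellAsString(
//       [](absl::Span<const int64> indices, const string& value) {
//         LOG(INFO) << absl::StrJoin(indices, ",") << ": " << value;
//       });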
EachCellAsString(const std::function<void (absl::Span<const int64> indices,const string & value)> & per_cell) const1325 void LiteralBase::EachCellAsString(
1326 const std::function<void(absl::Span<const int64> indices,
1327 const string& value)>& per_cell) const {
1328 if (ShapeUtil::IsZeroElementArray(shape())) {
1329 return;
1330 }
1331 std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
1332 shape(), /*linear_index=*/0);
1333 do {
1334 per_cell(indices, GetAsString(indices));
1335 } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
1336 }
1337
1338 namespace {
1339 template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
ConvertBetweenNativeTypesWithConverter(const LiteralBase & src_literal,const ConverterType & converter)1340 Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
1341 const ConverterType& converter) {
1342 CHECK(src_literal.shape().IsArray());
1343 Literal result_literal(ShapeUtil::ChangeElementType(
1344 src_literal.shape(),
1345 primitive_util::NativeToPrimitiveType<NativeDestT>()));
1346 auto src_data = src_literal.data<NativeSrcT>();
1347 auto dest_data = result_literal.template data<NativeDestT>();
1348 int64 num_elements = src_literal.element_count();
1349
1350 for (int64 i = 0; i < num_elements; ++i) {
1351 dest_data[i] = converter(src_data[i]);
1352 }
1353 return result_literal;
1354 }
1355
1356 template <typename NativeSrcT, typename NativeDestT>
1357 typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
1358 (std::is_same<NativeDestT, complex64>::value ||
1359 std::is_same<NativeDestT, complex128>::value),
1360 Literal>::type
ConvertBetweenNativeTypes(const LiteralBase & src_literal)1361 ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
1362 auto converter = [](NativeSrcT src) {
1363 return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
1364 };
1365 return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1366 src_literal, converter);
1367 }
1368
1369 template <typename NativeSrcT, typename NativeDestT>
1370 typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
1371 (!std::is_same<NativeDestT, complex64>::value &&
1372 !std::is_same<NativeDestT, complex128>::value),
1373 Literal>::type
ConvertBetweenNativeTypes(const LiteralBase & src_literal)1374 ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
1375 auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
1376 return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1377 src_literal, converter);
1378 }
1379
1380 template <typename NativeSrcT, typename NativeDestT>
1381 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
1382 !std::is_same<NativeDestT, Eigen::half>::value),
1383 Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1384 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1385 auto converter = [](NativeSrcT src) {
1386 return absl::bit_cast<NativeDestT>(GetRawValue(src));
1387 };
1388 return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1389 src_literal, converter);
1390 }
1391
1392 template <typename NativeSrcT, typename NativeDestT>
1393 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
1394 std::is_same<NativeDestT, Eigen::half>::value),
1395 Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1396 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1397 // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
1398 // cast to unsigned short and then use raw_uint16_to_half.
1399 auto converter = [](NativeSrcT src) {
1400 return Eigen::half_impl::raw_uint16_to_half(
1401 absl::bit_cast<uint16>(GetRawValue(src)));
1402 };
1403 return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
1404 src_literal, converter);
1405 }
1406
1407 // This template specialization is here to make the compiler happy. bit_cast has
1408 // a static check that the types are the same size. This specialization should
1409 // never be used because the source and destination types are checked for
1410 // identical sizes higher up.
1411 template <typename NativeSrcT, typename NativeDestT>
1412 typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
1413 Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1414 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1415 LOG(FATAL) << "Invalid bitcast between types of different sizes.";
1416 }

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}

template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}

StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
    // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}

}  // namespace

StatusOr<Literal> LiteralBase::Convert(
    PrimitiveType primitive_dest_type) const {
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
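
// Usage sketch (illustrative only, not part of this file's API surface):
// Convert performs an element-wise, value-changing conversion, e.g. S32 -> F32.
//
//   Literal s32 = LiteralUtil::CreateR1<int32>({1, 2, 3});
//   TF_ASSIGN_OR_RETURN(Literal f32, s32.Convert(F32));
//   // f32 holds {1.0f, 2.0f, 3.0f}; s32 is left unchanged.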

StatusOr<Literal> LiteralBase::BitcastConvert(
    PrimitiveType primitive_dest_type) const {
  if (primitive_util::BitWidth(shape().element_type()) !=
      primitive_util::BitWidth(primitive_dest_type)) {
    return InvalidArgument(
        "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
        "%d",
        PrimitiveType_Name(shape().element_type()),
        PrimitiveType_Name(primitive_dest_type),
        primitive_util::BitWidth(shape().element_type()),
        primitive_util::BitWidth(primitive_dest_type));
  }
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
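
// Usage sketch (illustrative only): BitcastConvert is a bit-preserving
// reinterpretation between same-width types; 1.0f has the IEEE-754 bit
// pattern 0x3F800000.
//
//   Literal f32 = LiteralUtil::CreateR0<float>(1.0f);
//   TF_ASSIGN_OR_RETURN(Literal u32, f32.BitcastConvert(U32));
//   // u32.Get<uint32>({}) == 0x3F800000.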

StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
  if (!dest_shape.IsTuple()) {
    return Convert(dest_shape.element_type());
  }
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    auto element = LiteralSlice(*this, {i});
    TF_ASSIGN_OR_RETURN(
        auto new_element,
        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
    elements.push_back(std::move(new_element));
  }
  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
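
// Usage sketch (illustrative only; `src` and `dest_shape` are assumed caller
// values): tuples are converted leaf-by-leaf to the corresponding destination
// element types.
//
//   // src : (s32[2], f32[2]) tuple literal; dest_shape : (f32[2], f64[2]).
//   TF_ASSIGN_OR_RETURN(Literal converted, src.ConvertToShape(dest_shape));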

/* static */ Literal MutableLiteralBase::MoveIntoTuple(
    absl::Span<Literal> elements) {
  std::vector<Shape> element_shapes;
  for (const Literal& element : elements) {
    element_shapes.push_back(element.shape());
  }
  Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
                  /*allocate_arrays=*/false);
  for (int i = 0, end = elements.size(); i < end; ++i) {
    TF_CHECK_OK(
        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
  }
  return literal;
}
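
// Usage sketch (illustrative only): assembling a tuple without copying
// buffers; the source literals are consumed (left in a moved-from state).
//
//   std::vector<Literal> parts;
//   parts.push_back(LiteralUtil::CreateR0<int32>(42));
//   parts.push_back(LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
//   Literal tuple = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(parts));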

template <typename NativeT>
void LiteralBase::Piece::CopyElementsWithDynamicBound(
    const LiteralBase::Piece& src) {
  auto dest_shape = subshape();
  auto src_shape = src.subshape();

  // At least one of the two shapes has to be static, to serve as the bound
  // for the index iteration below.
  CHECK(dest_shape.is_static() || src_shape.is_static());
  auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64> index(dest_shape.rank());
  do {
    bool out_of_bound = false;
    for (int64 i = 0; i < index.size(); ++i) {
      // Do not copy elements beyond the dynamic bound of either side.
      if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
        out_of_bound = true;
      }
    }
    if (out_of_bound) {
      continue;
    }
    data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
                                                                  index)] =
        src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
            src_shape, index)];
  } while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
}
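
// Illustration (an assumed scenario, not code that runs here): for two
// s32[<=4] pieces where the destination's dynamic size is 2 and the source's
// is 3, only indices 0 and 1 are copied; elements past either dynamic size
// are skipped, so the destination's padding is left untouched.
//
//   // dest.CopyElementsWithDynamicBound<int32>(src);
//   // dest[0] = src[0]; dest[1] = src[1];  // indices 2 and 3: skipped.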

template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64 i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}

bool LiteralBase::Piece::EqualDynamicSize(
    const LiteralBase::Piece& other) const {
  DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
  if (subshape().is_static()) {
    return true;
  }

  for (int64 i = 0; i < subshape().rank(); ++i) {
    if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
      return false;
    }
  }
  return true;
}

bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
  if (subshape().is_static() &&
      ShapeUtil::Equal(subshape(), other.subshape()) &&
      LayoutUtil::IsDenseArray(subshape())) {
    CHECK_EQ(size_bytes(), other.size_bytes());
    return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
  }

  std::vector<int64> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case S8:
      return EqualElementsInternal<int8>(other, &multi_index);
    case S16:
      return EqualElementsInternal<int16>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8>(other, &multi_index);
    case U16:
      return EqualElementsInternal<uint16>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    case C128:
      return EqualElementsInternal<complex128>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

bool LiteralBase::operator==(const LiteralBase& other) const {
  // Check that the tuple structure of the two literals matches; element types
  // and array contents are compared piece-by-piece below.
  if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
    return false;
  }

  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        const Piece& other_piece = other.piece(index);
        const Shape& subshape = piece.subshape();
        const Shape& other_subshape = other_piece.subshape();
        if (subshape.element_type() != other_subshape.element_type()) {
          return false;
        }
        if (!piece.subshape().IsArray()) {
          return true;
        }
        if (subshape.rank() != other_subshape.rank()) {
          return false;
        }

        for (int64 i = 0; i < subshape.rank(); ++i) {
          if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
            return false;
          }
        }

        if (!piece.EqualElements(other_piece)) {
          return false;
        }
        return true;
      });
}
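
// Usage sketch (illustrative only): equality requires matching structure,
// element types, dynamic sizes, and element values.
//
//   Literal a = LiteralUtil::CreateR1<int32>({1, 2, 3});
//   Literal b = LiteralUtil::CreateR1<int32>({1, 2, 3});
//   CHECK(a == b);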

namespace {

template <typename NativeT>
static bool AllElementsEqualValue(absl::Span<const NativeT> data,
                                  NativeT value) {
  for (int64 i = 0; i < data.size(); ++i) {
    if (data[i] != value) {
      return false;
    }
  }
  return true;
}

}  // namespace

bool LiteralBase::IsAll(int8 value) const {
  return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
                                                  const Piece& piece) {
    if (!piece.subshape().IsArray()) {
      return true;
    }

    auto piece_is_all = [&]() {
      switch (shape().element_type()) {
        case U8:
          if (value >= 0) {
            return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
          }
          return false;
        case U16:
          if (value >= 0) {
            return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
          }
          return false;
        case U32:
          if (value >= 0) {
            return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
          }
          return false;
        case U64:
          if (value >= 0) {
            return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
          }
          return false;
        case S8:
          return AllElementsEqualValue<int8>(piece.data<int8>(), value);
        case S16:
          return AllElementsEqualValue<int16>(piece.data<int16>(), value);
        case S32:
          return AllElementsEqualValue<int32>(piece.data<int32>(), value);
        case S64:
          return AllElementsEqualValue<int64>(piece.data<int64>(), value);
        case F32:
          return AllElementsEqualValue<float>(piece.data<float>(), value);
        case F64:
          return AllElementsEqualValue<double>(piece.data<double>(), value);
        case F16:
          return AllElementsEqualValue<half>(piece.data<half>(),
                                             static_cast<half>(value));
        case BF16:
          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
                                                 static_cast<bfloat16>(value));
        case PRED:
          if (value == 0) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), false);
          }
          if (value == 1) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), true);
          }
          return false;
        default:
          return false;
      }
      return false;
    };

    if (!piece_is_all()) {
      return false;
    }
    return true;
  });
}
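
// Usage sketch (illustrative only): IsAll interprets `value` in the literal's
// own element type, so sign and range matter.
//
//   Literal ones = LiteralUtil::CreateR2<uint32>({{1, 1}, {1, 1}});
//   CHECK(ones.IsAll(1));    // Every element equals 1.
//   CHECK(!ones.IsAll(-1));  // A negative value never matches unsigned data.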

bool LiteralBase::IsAllFloat(float value) const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        switch (shape().element_type()) {
          case F32:
            return AllElementsEqualValue<float>(piece.data<float>(), value);
          case F64:
            return AllElementsEqualValue<double>(piece.data<double>(), value);
          case F16:
            return AllElementsEqualValue<half>(piece.data<half>(),
                                               static_cast<half>(value));
          case BF16:
            return AllElementsEqualValue<bfloat16>(
                piece.data<bfloat16>(), static_cast<bfloat16>(value));
          default:
            return false;
        }
      });
}

bool LiteralBase::IsAllComplex(complex64 value) const {
  switch (shape().element_type()) {
    case C64:
      return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
                                              value);
    case C128:
      return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
                                               value);
    default:
      return false;
  }
}

bool LiteralBase::IsAllFirst() const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // A zero-element array is never "all equal to its first element",
        // because it has no first element.
        if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
          return false;
        }
        auto piece_is_all = [&]() {
          switch (piece.subshape().element_type()) {
            case PRED: {
              auto data = piece.data<bool>();
              return AllElementsEqualValue<bool>(data, data[0]);
            }
            // 8 bit types
            case S8: {
              auto data = piece.data<int8>();
              return AllElementsEqualValue<int8>(data, data[0]);
            }
            case U8: {
              auto data = piece.data<uint8>();
              return AllElementsEqualValue<uint8>(data, data[0]);
            }
            // 16 bit types
            case BF16: {
              auto data = piece.data<bfloat16>();
              return AllElementsEqualValue<bfloat16>(data, data[0]);
            }
            case F16: {
              auto data = piece.data<half>();
              return AllElementsEqualValue<half>(data, data[0]);
            }
            case S16: {
              auto data = piece.data<int16>();
              return AllElementsEqualValue<int16>(data, data[0]);
            }
            case U16: {
              auto data = piece.data<uint16>();
              return AllElementsEqualValue<uint16>(data, data[0]);
            }
            // 32 bit types
            case F32: {
              auto data = piece.data<float>();
              return AllElementsEqualValue<float>(data, data[0]);
            }
            case U32: {
              auto data = piece.data<uint32>();
              return AllElementsEqualValue<uint32>(data, data[0]);
            }
            case S32: {
              auto data = piece.data<int32>();
              return AllElementsEqualValue<int32>(data, data[0]);
            }
            // 64 bit types
            case C64: {
              auto data = piece.data<complex64>();
              return AllElementsEqualValue<complex64>(data, data[0]);
            }
            case F64: {
              auto data = piece.data<double>();
              return AllElementsEqualValue<double>(data, data[0]);
            }
            case S64: {
              auto data = piece.data<int64>();
              return AllElementsEqualValue<int64>(data, data[0]);
            }
            case U64: {
              auto data = piece.data<uint64>();
              return AllElementsEqualValue<uint64>(data, data[0]);
            }

            case C128: {
              auto data = piece.data<complex128>();
              return AllElementsEqualValue<complex128>(data, data[0]);
            }
            default:
              return false;
          }
        };

        if (!piece_is_all()) {
          return false;
        }
        return true;
      });
}
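
// Usage sketch (illustrative only):
//
//   Literal splat = LiteralUtil::CreateR1<float>({2.5f, 2.5f, 2.5f});
//   CHECK(splat.IsAllFirst());
//   Literal mixed = LiteralUtil::CreateR1<float>({2.5f, 3.0f});
//   CHECK(!mixed.IsAllFirst());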

bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  auto is_iota_at_idx = [&](const int64 idx) {
    switch (shape().element_type()) {
      case U8:
        return static_cast<int64>(Get<uint8>({idx})) == idx;
      case U16:
        return static_cast<int64>(Get<uint16>({idx})) == idx;
      case U32:
        return static_cast<int64>(Get<uint32>({idx})) == idx;
      case U64:
        return static_cast<int64>(Get<uint64>({idx})) == idx;
      case S8:
        return Get<int8>({idx}) == idx;
      case S16:
        return Get<int16>({idx}) == idx;
      case S32:
        return Get<int32>({idx}) == idx;
      case S64:
        return Get<int64>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // PRED, token, opaque, tuple, etc. are never considered iotas.
      default:
        return false;
    }
  };

  const int64 elements = ShapeUtil::ElementsIn(shape());
  for (int64 idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}
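
// Usage sketch (illustrative only):
//
//   Literal iota = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
//   CHECK(iota.IsR1Iota());
//   Literal not_iota = LiteralUtil::CreateR1<int32>({1, 2, 3});
//   CHECK(!not_iota.IsR1Iota());  // An iota must start at 0.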

bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8>(indices) == 0;
    case U16:
      return Get<uint16>(indices) == 0;
    case U32:
      return Get<uint32>(indices) == 0;
    case U64:
      return Get<uint64>(indices) == 0;
    case S8:
      return Get<int8>(indices) == 0;
    case S16:
      return Get<int16>(indices) == 0;
    case S32:
      return Get<int32>(indices) == 0;
    case S64:
      return Get<int64>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "IsZero is not supported for element type "
                 << PrimitiveType_Name(shape().element_type());
  }
}
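
// Usage sketch (illustrative only): IsZero inspects a single element.
//
//   Literal m = LiteralUtil::CreateR2<int32>({{0, 7}, {0, 0}});
//   CHECK(m.IsZero({0, 0}));
//   CHECK(!m.IsZero({0, 1}));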

namespace {

template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}

}  // namespace

void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case U16:
      *proto->mutable_u16s() = string(
          reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape, which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

const void* LiteralBase::Piece::untyped_data() const {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

void* LiteralBase::Piece::untyped_data() {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
                             const RepeatedFieldT& src) {
  if (dest.size() != src.size()) {
    return InvalidArgument(
        "Expected %lu elements in LiteralProto repeated field, has %d",
        dest.size(), src.size());
  }
  std::copy(src.begin(), src.end(), dest.begin());
  return Status::OK();
}

}  // namespace

Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      auto s8_data = data<int8>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    case S16: {
      const string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;

    case BF16: {
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      const int64 complex_data_size_doubled = complex_data.size() * 2;
      TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled);
      for (int64 i = 0, end = complex_data.size(); i < end; ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return Status::OK();
}

LiteralProto LiteralBase::ToProto() const {
  LiteralProto proto;
  root_piece().ForEachSubpiece(
      [&](const ShapeIndex& index, const Piece& piece) {
        LiteralProto* proto_piece = &proto;
        for (int64 i : index) {
          while (proto_piece->tuple_literals_size() <= i) {
            proto_piece->add_tuple_literals();
          }
          proto_piece = proto_piece->mutable_tuple_literals(i);
        }
        piece.WriteToProto(proto_piece);
      });

  return proto;
}
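
// Usage sketch (illustrative only; `literal` is an assumed caller value):
// round-tripping a literal through its proto form, using the CreateFromProto
// factory declared in literal.h.
//
//   LiteralProto proto = literal.ToProto();
//   TF_ASSIGN_OR_RETURN(Literal restored,
//                       MutableLiteralBase::CreateFromProto(proto));
//   CHECK(restored == literal);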

const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
  return piece(shape_index).untyped_data();
}

void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
  return piece(shape_index).untyped_data();
}

int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
  return piece(shape_index).size_bytes();
}

string LiteralBase::GetR1U8AsString() const {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(shape().element_type(), U8);
  return string(absl::bit_cast<const char*>(data<uint8>().data()),
                ShapeUtil::ElementsIn(shape()));
}
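
// Usage sketch (illustrative only): a rank-1 U8 literal can be read back as a
// string of raw bytes.
//
//   Literal bytes = LiteralUtil::CreateR1<uint8>({'a', 'b', 'c'});
//   CHECK_EQ(bytes.GetR1U8AsString(), "abc");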

void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
                                               Piece* src_piece,
                                               Piece* dest_piece) {
  DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
      << "src_piece has shape: "
      << ShapeUtil::HumanString(src_piece->subshape())
      << "; dest_piece has shape: "
      << ShapeUtil::HumanString(dest_piece->subshape());
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);

      dest_piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    dest_piece->set_buffer(src_piece->buffer());
    dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
  } else {
    // If the shape is neither an array nor a tuple, it must be zero-sized;
    // otherwise some memory would need to be allocated for it.
    CHECK_EQ(dest_piece->size_bytes(), 0);
  }
}

MutableLiteralBase::~MutableLiteralBase() {}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}

MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);

  return *this;
}

MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal->shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    MutableBorrowingLiteral literal, const ShapeIndex& view_root)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  CHECK(!shape_->IsTuple());

  root_piece_ = new Piece();
  root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_->set_subshape(shape_.get());
}
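
// Usage sketch (illustrative only; `host_buffer` is an assumed caller-owned
// allocation): wrapping existing memory without copying. The caller must keep
// the buffer alive for the lifetime of the view; writes go through to it.
//
//   std::vector<float> host_buffer(16, 0.0f);
//   MutableBorrowingLiteral view(
//       reinterpret_cast<const char*>(host_buffer.data()),
//       ShapeUtil::MakeShape(F32, {4, 4}));
//   view.Set<float>({0, 0}, 1.0f);  // host_buffer[0] is now 1.0f.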

MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  if (!shape_->IsTuple()) {
    CHECK_EQ(src_buf_ptrs.size(), 1);
    root_piece_ = new Piece();
    root_piece_->set_buffer(const_cast<char*>(src_buf_ptrs[0]));
    root_piece_->set_subshape(shape_.get());
  } else {
    CHECK(!ShapeUtil::IsNestedTuple(*shape_));
    CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());

    for (int i = 0; i < src_buf_ptrs.size(); ++i) {
      Piece child_piece;
      const auto& src_shape = shape_->tuple_shapes(i);
      CHECK(src_shape.IsArray());
      child_piece.set_subshape(&src_shape);
      child_piece.set_buffer(src_buf_ptrs[i]);
      root_piece_->emplace_back(std::move(child_piece));
    }
  }
}

MutableBorrowingLiteral::~MutableBorrowingLiteral() {
  if (root_piece_ != nullptr) {
    delete root_piece_;
  }
}

LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}

LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}

void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
  CHECK(shape.IsTuple());
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);

    auto child_piece = Piece();
    child_piece.set_subshape(&subshape);

    if (subshape.IsTuple()) {
      BuildPieceSubtree(subshape, &child_piece);
    }

    piece->emplace_back(std::move(child_piece));
  }
}

BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsArray());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = Piece();
  root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_.set_subshape(shape_.get());
}

BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                                   const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsTuple());
  CHECK(!ShapeUtil::IsNestedTuple(*shape_));
  CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  BuildPieceSubtree(*shape_, &root_piece_);

  for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
    const auto& src_shape = shape_->tuple_shapes(i);
    CHECK(src_shape.IsArray());
    root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
  }
}
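
// Usage sketch (illustrative only): a read-only view over caller-owned bytes;
// the buffer must outlive the BorrowingLiteral.
//
//   std::array<int32, 3> values = {1, 2, 3};
//   BorrowingLiteral view(reinterpret_cast<const char*>(values.data()),
//                         ShapeUtil::MakeShape(S32, {3}));
//   CHECK_EQ(view.Get<int32>({1}), 2);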

}  // namespace xla