/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <numeric>
#include <type_traits>
#include <vector>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using absl::StrCat;

constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Literals can be used as DMA targets, which can require alignment. We
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
// alignment.
constexpr int kMinimumAlignment = 64;

// Converts between little and big endian.
//
// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
void ConvertEndianShort(string* bytes) {
  CHECK_EQ(bytes->size() % 2, 0);
  for (int64_t i = 0, end = bytes->size(); i < end; i += 2) {
    std::swap((*bytes)[i], (*bytes)[i + 1]);
  }
}

void ConvertEndianShort(char* bytes, int64_t size) {
  CHECK_EQ(size % 2, 0);
  for (int64_t i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}
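
// Illustrative example for ConvertEndianShort above (not part of the
// original source): the 16-bit buffer {0x12, 0x34, 0x56, 0x78} becomes
// {0x34, 0x12, 0x78, 0x56} -- each two-byte element has its bytes swapped.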

string CompactOneline(const string& input) {
  string result;
  std::vector<string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
  bool first = true;
  // Concatenate elements in "v" with spaces separating them, but ignoring
  // empty entries.
  for (const auto& s : v) {
    if (s.empty()) {
      continue;
    }
    absl::StrAppend(&result, (first ? "" : " "), s);
    first = false;
  }
  return result;
}
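
// Illustrative example for CompactOneline above (not part of the original
// source): the multi-line input
//   "f32[2,2] {\n { 1, 2 },\n { 3, 4 }\n}"
// collapses to
//   "f32[2,2] { { 1, 2 }, { 3, 4 } }"
// since runs of newlines and spaces shrink to single spaces.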

// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  return val;
}
uint16 GetRawValue(Eigen::half val) {
  return Eigen::numext::bit_cast<uint16>(val);
}

bool LiteralProtoHasValues(const LiteralProto& proto) {
  return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
         proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
         proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
         proto.c64s_size() || proto.c128s_size() ||
         proto.tuple_literals_size() || !proto.f16s().empty() ||
         !proto.bf16s().empty() || !proto.u16s().empty() ||
         !proto.s16s().empty();
}

}  // namespace

LiteralBase::~LiteralBase() {}

std::ostream& operator<<(std::ostream& out, const Literal& literal) {
  out << literal.ToString();
  return out;
}

MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}
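
// Illustrative example for StrideConfig above (not part of the original
// source): copying a {4, 1024} region from a {1,0}-layout source into a
// {0,1}-layout destination. dimensions[minor(source)] = 1024 >=
// dimensions[minor(dest)] = 4, so dimension 1 becomes the minor loop: each
// inner copy moves 1024 elements, reading the source contiguously and
// writing the destination at its stride for dimension 1.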

Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}

void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    if (allocate_arrays) {
      piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
          piece->size_bytes(), kMinimumAlignment)));
      if (shape.is_dynamic()) {
        CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
        piece->set_dynamic_size_buffer(
            static_cast<int32*>(tensorflow::port::AlignedMalloc(
                piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
      }
    }
  } else {
    // If the shape is neither an array nor a tuple, then it must be
    // zero-sized; otherwise memory would have to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}

Literal::Literal(const Shape& shape, bool allocate_arrays)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  CHECK(&root_piece_->subshape() == shape_.get());

  SetPiece(*shape_, root_piece_, allocate_arrays);
}

Literal::~Literal() {
  if (root_piece_ != nullptr) {
    DeallocateBuffers();
    delete root_piece_;
  }
}

void Literal::DeallocateBuffers() {
  root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->buffer());
        }
        if (piece->dynamic_size_buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
        }
      });
}

Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}

Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_->subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  DCHECK(&root_piece_->subshape() == shape_.get());

  return *this;
}

Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}
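
// Illustrative usage of CreateFromShape above (not part of the original
// source): a zero-initialized 2x3 float literal.
//   Literal zeros =
//       Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {2, 3}));
// All six elements read back as 0.0f, since every array buffer is memset.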

int32 LiteralBase::GetDynamicSize(int64_t dim_index) const {
  return GetDynamicSize(dim_index, {});
}

int32 LiteralBase::GetDynamicSize(int64_t dim_index,
                                  const ShapeIndex& shape_index) const {
  return piece(shape_index).GetDynamicSize(dim_index);
}

absl::optional<int64> LiteralBase::GetFirstInteger() const {
  switch (shape().element_type()) {
    case U8:
      return GetFirstElement<uint8>();
    case U16:
      return GetFirstElement<uint16>();
    case U32:
      return GetFirstElement<uint32>();
    case U64: {
      // A uint64 value above INT64_MAX wraps to a negative int64 here, in
      // which case it cannot be represented in the result type and nullopt
      // is returned.
      int64_t v = GetFirstElement<uint64>();
      if (v < 0) {
        return absl::nullopt;
      }
      return v;
    }
    case S8:
      return GetFirstElement<int8>();
    case S16:
      return GetFirstElement<int16>();
    case S32:
      return GetFirstElement<int32>();
    case S64:
      return GetFirstElement<int64>();
    default:
      return absl::nullopt;
  }
}

template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64> src_base,
    absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
  const int64_t src_base_size = src_base.size();
  const int64_t dest_base_size = dest_base.size();
  TF_RET_CHECK(src_literal.shape().rank() == src_base_size);
  TF_RET_CHECK(shape().rank() == dest_base_size);

  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If either shape is a scalar, we can just call StridedCopy() directly,
    // and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform the copy only if neither src nor dest has a zero-element
    // dimension; otherwise the copy is a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from the minor dimension, stepping in copy-size
    // blocks. Within the index enumeration functor, do a strided copy that
    // advances the source index by one (walking through the minor dimension)
    // and the destination index by the proper stride size at the matching
    // dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64_t src_index = linear_index(src_literal.shape(), src_indexes);
      int64_t dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to work around MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}

Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
                                           absl::Span<const int64> src_index,
                                           absl::Span<const int64> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64_t src_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(src_literal.shape(),
                                                    src_index);
  const int64_t dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64_t primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return Status::OK();
}
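
// Illustrative usage of CopyElementFrom above (not part of the original
// source), assuming `src` and `dst` are literals of matching element type:
//   TF_RETURN_IF_ERROR(dst.CopyElementFrom(src, /*src_index=*/{1, 2},
//                                          /*dest_index=*/{0, 0}));
// Exactly one element's bytes are copied; no conversion takes place.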

/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto, bool prohibit_empty_literal) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
    return InvalidArgument(
        "Literal shape cannot include OPAQUE_TYPE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        const LiteralProto* proto_element = &proto;
        for (int64_t i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return Status::OK();
        }
        if (piece->subshape().element_type() == TOKEN) {
          return Status::OK();
        }

        CHECK(piece->subshape().IsArray());

        // When prohibit_empty_literal is false (allowing literals with no
        // values), only copy from the proto if the literal proto has values.
        // This mode is used for a learned cost model.
        if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
          TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
        }

        return Status::OK();
      }));

  return std::move(literal);
}

Literal Literal::SubLiteral(ShapeIndexView shape_index) {
  if (!shape_index.empty()) {
    auto decomposed = this->DecomposeTuple();
    return decomposed.at(shape_index.front())
        .SubLiteral(shape_index.ConsumeFront());
  } else {
    return std::move(*this);
  }
}

std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          ShapeIndex src_index = {i};
          for (int64_t j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->set_buffer(src_piece.buffer());
          dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
          src_piece.set_buffer(nullptr);
          src_piece.set_dynamic_size_buffer(nullptr);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
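
// Illustrative usage of DecomposeTuple above (not part of the original
// source):
//   Literal tuple = LiteralUtil::MakeTupleOwned(
//       LiteralUtil::CreateR0<float>(1.0f),
//       LiteralUtil::CreateR1<int32>({2, 3}));
//   std::vector<Literal> parts = tuple.DecomposeTuple();
// The element buffers are moved, not copied, and `tuple` ends up
// nil-shaped.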

namespace {

// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}

}  // namespace

int32 LiteralBase::Piece::GetDynamicSize(int64_t dim_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  if (!subshape_->is_dynamic_dimension(dim_index)) {
    // This is a static dimension, return size.
    return subshape_->dimensions(dim_index);
  }
  CHECK_NE(dynamic_size_buffer(), nullptr);
  return dynamic_size_buffer_[dim_index];
}

void LiteralBase::Piece::SetDynamicSize(int64_t dim_index, int32_t size) {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  CHECK(subshape_->is_dynamic_dimension(dim_index));
  if (dynamic_size_buffer() == nullptr) {
    // Lazily initialize the dynamic size buffer.
    set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
        dynamic_size_buffer_bytes(), kMinimumAlignment)));
    /*for (int64 i = 0; i < subshape().rank(); ++i) {
      // Initialized to -1 to help debug.
      dynamic_size_buffer_[i] = -1;
    }*/
  }
  dynamic_size_buffer_[dim_index] = size;
}

Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
                                    bool only_dynamic_bound) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    std::vector<int64> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                      \
  case (XLA_T):                                                             \
    if (only_dynamic_bound) {                                               \
      CopyElementsWithDynamicBound<NATIVE_T>(src);                          \
    } else {                                                                \
      CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                    subshape(), src.subshape());            \
    }                                                                       \
    break;
      COPY_ELEMENTS(U8, uint8);
      COPY_ELEMENTS(U16, uint16);
      COPY_ELEMENTS(U32, uint32);
      COPY_ELEMENTS(U64, uint64);
      COPY_ELEMENTS(S8, int8);
      COPY_ELEMENTS(S16, int16);
      COPY_ELEMENTS(S32, int32);
      COPY_ELEMENTS(S64, int64);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
  if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
    CHECK_NE(dynamic_size_buffer_, nullptr);
    CHECK_NE(src.dynamic_size_buffer_, nullptr);
    memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
           src.dynamic_size_buffer_bytes());
  }
  return Status::OK();
}

void MutableLiteralBase::SetDynamicSize(int64_t dim_index, int32_t size) {
  return SetDynamicSize(dim_index, {}, size);
}

void MutableLiteralBase::SetDynamicSize(int64_t dim_index,
                                        const ShapeIndex& shape_index,
                                        int32_t size) {
  Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
  CHECK_GE(subshape_->dimensions(dim_index), size);
  if (subshape_->dimensions(dim_index) == size) {
    subshape_->set_dynamic_dimension(dim_index, false);
    return;
  }
  subshape_->set_dynamic_dimension(dim_index, true);
  piece(shape_index).SetDynamicSize(dim_index, size);
}
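
// Illustrative usage of SetDynamicSize above (not part of the original
// source): for a literal of shape f32[<=10], record that only 6 elements
// are valid:
//   literal.SetDynamicSize(/*dim_index=*/0, /*size=*/6);
// When `size` equals the dimension's bound, the dimension is instead
// marked static and nothing is written to the dynamic size buffer.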

Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index,
                                    bool only_dynamic_bound) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (only_dynamic_bound) {
    auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
    auto compact_shape =
        dest_subshape.is_static() ? dest_subshape : src_subshape;
    CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
        << compact_shape.ToString() << " vs " << bound_shape.ToString();
  } else {
    if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
      return InvalidArgument(
          "Destination subshape incompatible with source subshape: %s vs %s",
          ShapeUtil::HumanString(dest_subshape),
          ShapeUtil::HumanString(src_subshape));
    }
  }
  return root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return Status::OK();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return Status::OK();
        }
        // Construct the index of the corresponding piece in the source
        // literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64_t i = dest_shape_index.size(), end = index.size(); i < end;
             ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(
            piece->CopyFrom(src_literal.piece(src_piece_index),
                            /*only_dynamic_bound=*/only_dynamic_bound));
        return Status::OK();
      });
}

Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_->ForEachSubpiece(
      [&](const ShapeIndex& src_index, const Piece& src_piece) {
        if (!src_piece.subshape().IsArray()) {
          return;
        }

        ShapeIndex dest_index = dest_shape_index;
        for (int64_t i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        tensorflow::port::AlignedFree(dest_piece.buffer());
        tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
        dest_piece.set_buffer(src_piece.buffer());
        dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
      });

  src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
  delete src_literal.root_piece_;
  src_literal.root_piece_ = new LiteralBase::Piece();
  src_literal.root_piece_->set_subshape(src_literal.shape_.get());

  return Status::OK();
}
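
// Illustrative usage of MoveFrom above (not part of the original source):
// move an array literal into element {1} of a tuple-shaped destination
// without copying buffers:
//   TF_RETURN_IF_ERROR(tuple_literal.MoveFrom(std::move(array_literal),
//                                             /*dest_shape_index=*/{1}));
// `array_literal` is left nil-shaped afterwards and must not be reused.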

Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64> src_base,
                                         absl::Span<const int64> dest_base,
                                         absl::Span<const int64> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
                                          copy_size);
    case U16:
      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
                                           copy_size);
    case U32:
      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
                                           copy_size);
    case U64:
      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
                                           copy_size);
    case S8:
      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
                                         copy_size);
    case S16:
      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
                                          copy_size);
    case S32:
      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
                                          copy_size);
    case S64:
      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
                                          copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
                                               copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}

void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(element_count(), values.bits());
  CHECK_EQ(shape().element_type(), PRED);
  for (int64_t i = 0; i < static_cast<int64>(values.bits()); ++i) {
    Set({i}, values.get(i));
  }
}
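
// Illustrative usage of PopulateR1 above (not part of the original
// source), assuming `pred_literal` has shape pred[3]:
//   tensorflow::core::Bitmap bits(3);
//   bits.set(0);
//   bits.set(2);
//   pred_literal.PopulateR1(bits);  // elements become {1, 0, 1}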

Literal LiteralBase::Relayout(const Layout& new_layout,
                              const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}
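
// Illustrative usage of Relayout above (not part of the original source):
// convert a row-major f32[2,3] literal to column-major storage; logical
// contents are unchanged, only the physical element order differs.
//   Literal col_major = row_major.Relayout(LayoutUtil::MakeLayout({0, 1}));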

Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
  CHECK(bounded_shape.is_dynamic());
  Literal result(bounded_shape);
  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape.rank(); ++i) {
          result.SetDynamicSize(i, subshape.dimensions(i));
        }
      });
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));

  return result;
}

Literal LiteralBase::ToStatic() const {
  // Create a new shape where every dynamic dimension is replaced by a static
  // dimension of its current dynamic size.
  Shape new_shape = shape();
  ShapeUtil::ForEachMutableSubshape(
      &new_shape, [this](Shape* subshape, const ShapeIndex& index) {
        if (!subshape->IsArray()) {
          return;
        }
        for (int64_t i = 0; i < subshape->rank(); ++i) {
          subshape->set_dynamic_dimension(i, false);
          subshape->set_dimensions(i, GetDynamicSize(i, index));
        }
      });
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
  return result;
}
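
// Illustrative example for ToStatic/ToBoundedDynamic above (not part of
// the original source): a literal of shape f32[<=10] with dynamic size 6
// becomes a plain f32[6] literal under ToStatic();
// ToBoundedDynamic(f32[<=10]) performs the inverse, re-attaching the bound
// and recording 6 as the dynamic size.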

StatusOr<Literal> LiteralBase::Broadcast(
    const Shape& result_shape, absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Broadcast only supports arrays.");
  }

  for (int64_t i = 0, end = dimensions.size(); i < end; i++) {
    TF_RET_CHECK(shape().dimensions(i) ==
                 result_shape.dimensions(dimensions[i]));
  }

  Literal result(result_shape);

  // scratch_source_index is temporary storage space for the computed index
  // into the input literal. We put it here to avoid allocating an std::vector
  // in every iteration of ShapeUtil::ForEachIndex.
  std::vector<int64> scratch_source_index(shape().dimensions_size());

  char* dest_data = static_cast<char*>(result.untyped_data());
  const char* source_data = static_cast<const char*>(untyped_data());
  const int64_t primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  for (int64_t i = 0; i < dimensions.size(); ++i) {
    int64_t dynamic_size = GetDynamicSize(i);
    result.SetDynamicSize(dimensions[i], dynamic_size);
  }

  ShapeUtil::ForEachIndex(
      result_shape, [&](absl::Span<const int64> output_index) {
        for (int64_t i = 0, end = dimensions.size(); i < end; ++i) {
          scratch_source_index[i] = output_index[dimensions[i]];
        }
        int64_t dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            result_shape, output_index);
        int64_t source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            shape(), scratch_source_index);
        memcpy(dest_data + primitive_size * dest_index,
               source_data + primitive_size * source_index, primitive_size);
        return true;
      });

  return std::move(result);
}

StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  if (shape().is_dynamic()) {
    return Unimplemented("Dynamic reshape is not implemented.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64_t elements_before = ShapeUtil::ElementsIn(shape());
  int64_t elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
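
// Illustrative usage of Reshape above (not part of the original source):
// reinterpret a f32[2,3] literal as f32[3,2] (same six values, new
// dimensions):
//   TF_ASSIGN_OR_RETURN(Literal reshaped, literal.Reshape({3, 2}));
// A non-monotonic layout is first relaid out to the default layout so the
// flat element order can be reused as-is.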

Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the
  // one of size 8) as the most-minor one.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor-to-major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor-to-major
  // vector of the affine layout.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  for (int64_t i = 0; i < shape().rank(); i++) {
    new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
  }
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}
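
// Illustrative usage of Transpose above (not part of the original source):
//   Literal transposed = literal.Transpose({1, 0});
// For a f32[2,3] input this yields f32[3,2] where
// transposed.Get<float>({j, i}) == literal.Get<float>({i, j}); no element
// moves in memory, only the dimensions and minor-to-major order change.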

template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  CHECK(result_literal
            .Populate<NativeT>([&](absl::Span<const int64> indices) {
              for (int64_t i = 0; i < result_shape.rank(); ++i) {
                new_indices[i] = indices[i] + start_indices[i];
              }
              return Get<NativeT>(new_indices);
            })
            .ok());
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    if (shape().is_dynamic_dimension(dnum)) {
      int64_t dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
      CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
      dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
      result_literal.SetDynamicSize(dnum, dynamic_size);
    }
  }
  return result_literal;
}

Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
                           absl::Span<const int64> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64_t dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64_t dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}
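
// Illustrative usage of Slice above (not part of the original source):
// take rows 1..2 and columns 0..1 of a f32[3,3] literal, producing a
// f32[2,2] copy:
//   Literal block = literal.Slice(/*start_indices=*/{1, 0},
//                                 /*limit_indices=*/{3, 2});
// Unlike Transpose, Slice copies element by element via Populate.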

Literal LiteralBase::Clone() const {
  Literal result(shape());
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return RoundTripFpToString(Get<half>(multi_index, shape_index));
    case F32:
      return RoundTripFpToString(Get<float>(multi_index, shape_index));
    case BF16:
      return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
    case F64:
      return RoundTripFpToString(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}
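
// Illustrative outputs of GetAsString above (not part of the original
// source): a PRED element renders as "true" or "false", an s32 element as
// e.g. "42", and a c64 element as "(1.5, -2)". Floating-point values go
// through RoundTripFpToString so the text parses back to the exact value.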

absl::optional<int64> LiteralBase::GetIntegralAsS64(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case S8:
      return Get<int8>(multi_index);
    case U8:
      return Get<uint8>(multi_index);
    case S16:
      return Get<int16>(multi_index);
    case U16:
      return Get<uint16>(multi_index);
    case S32:
      return Get<int32>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    default:
      return absl::nullopt;
  }
}

absl::optional<double> LiteralBase::GetAsDouble(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      return static_cast<double>(Get<half>(multi_index));
    case F32:
      return static_cast<double>(Get<float>(multi_index));
    case F64:
      return Get<double>(multi_index);
    case BF16:
      return static_cast<double>(Get<bfloat16>(multi_index));
    default:
      return absl::nullopt;
  }
}

absl::optional<complex128> LiteralBase::GetAsComplex128(
    absl::Span<const int64> multi_index) const {
  switch (shape().element_type()) {
    case BF16:
      return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
    case F16:
      return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
    case F32:
      return {{Get<float>(multi_index), 0}};
    case F64:
      return {{Get<double>(multi_index), 0}};
    case C64:
      return {Get<complex64>(multi_index)};
    case C128:
      return {Get<complex128>(multi_index)};
    case S8:
      return {Get<int8>(multi_index)};
    default:
      return absl::nullopt;
  }
}

size_t LiteralBase::Hash() const {
  using tensorflow::Hash64;
  using tensorflow::Hash64Combine;

  size_t hash_value = ShapeUtil::Hash(shape());

  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }

        CHECK(LayoutUtil::IsDense(subshape.layout()));
        hash_value = Hash64Combine(
            hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
                               size_bytes(index)));
      });

  return hash_value;
}

Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
                                            int64_t value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      Set<bool>(multi_index, value);
      break;
    case U8:
      Set<uint8>(multi_index, value);
      break;
    case S32:
      Set<int32>(multi_index, value);
      break;
    case S64:
      Set<int64>(multi_index, value);
      break;
    case U32:
      Set<uint32>(multi_index, value);
      break;
    case U64:
      Set<uint64>(multi_index, value);
      break;
    default:
      return FailedPrecondition("Array element type is not integral: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}

Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
                                         double value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      Set<half>(multi_index, Eigen::half(value));
      break;
    case F32:
      Set<float>(multi_index, value);
      break;
    case F64:
      Set<double>(multi_index, value);
      break;
    case BF16:
      Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
      break;
    default:
      return FailedPrecondition("Array element type is not floating: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}

namespace {

string ShapeToString(bool print_layout, const Shape& shape) {
  return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
                      : ShapeUtil::HumanString(shape);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces);

void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<string> tuple_pieces;
  for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}

void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64_t rank = subshape.rank();

  std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
      to_string_recursive = [&](absl::Span<const int64> dimensions,
                                std::vector<int64>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        auto brace_to_string = [&](string brace) -> string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            const int64_t accum_indices_size = accum_indices->size();
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices_size < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            std::vector<int64> cloned_indices(*accum_indices);
            cloned_indices.push_back(i);
            to_string_recursive(dimensions.subspan(1), &cloned_indices);
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    if (subshape.is_dynamic()) {
      pieces->push_back("(");
      for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
        pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
        if (i < subshape.dimensions_size() - 1) {
          pieces->push_back(",");
        }
      }
      pieces->push_back(")");
    }
    pieces->push_back(" ");
  }
  std::vector<int64> indices = {};
  std::vector<int64> dimensions;
  dimensions.reserve(subshape.rank());
  for (int64_t i = 0; i < subshape.rank(); ++i) {
    dimensions.push_back(literal.GetDynamicSize(i, shape_index));
  }
  to_string_recursive(dimensions, &indices);
}
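
// Illustrative output of DenseArrayToStringHelper above (not part of the
// original source): a f32[2,2] literal holding {{1, 2}, {3, 4}} prints,
// modulo exact spacing, as
//   f32[2,2] {
//     { 1, 2 },
//     { 3, 4 }
//   }
// and arrays of rank greater than 3 annotate inner opening braces with
// /*i<dim>=<value>*/ markers.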

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                             pieces);
  }
}

}  // namespace

string LiteralBase::ToString() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringOneline() const {
  return CompactOneline(ToString());
}

string LiteralBase::ToStringWithoutShape() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/false,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithoutShapeOneline() const {
  return CompactOneline(ToStringWithoutShape());
}

string LiteralBase::ToStringWithLayout() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/true, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithLayoutOneline() const {
  return CompactOneline(ToStringWithLayout());
}

void LiteralBase::EachCellAsString(
    const std::function<void(absl::Span<const int64> indices,
                             const string& value)>& per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
      shape(), /*linear_index=*/0);
  do {
    per_cell(indices, GetAsString(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}

namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
                                               const ConverterType& converter) {
  CHECK(src_literal.shape().IsArray());
  Literal result_literal(ShapeUtil::ChangeElementType(
      src_literal.shape(),
      primitive_util::NativeToPrimitiveType<NativeDestT>()));
  auto src_data = src_literal.data<NativeSrcT>();
  auto dest_data = result_literal.template data<NativeDestT>();
  int64_t num_elements = src_literal.element_count();

  for (int64_t i = 0; i < num_elements; ++i) {
    dest_data[i] = converter(src_data[i]);
  }
  return result_literal;
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_same<NativeSrcT, Eigen::half>::value &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<std::is_floating_point<NativeSrcT>::value &&
                            std::is_integral<NativeDestT>::value,
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    // C++ [conv.bool]p1:
    //   A prvalue of arithmetic [...] type can be converted to a prvalue of
    //   type bool. A zero value [...] is converted to false; any other value
    //   is converted to true.
    // C++ [conv.fpint]p1:
    //   [...] The behavior is undefined if the truncated value cannot be
    //   represented in the destination type.
    //
    // Using static_cast to convert a float to an integral type other than
    // bool may be undefined if the value's magnitude is too large or it is a
    // NaN. Let's choose saturating arithmetic as it captures the spirit of
    // infinity and arbitrarily map NaN to zero.
    if (!std::is_same<NativeDestT, bool>::value) {
      if (src != src) {
        return NativeDestT{0};
      }
      if (src >= std::numeric_limits<NativeDestT>::max()) {
        return std::numeric_limits<NativeDestT>::max();
      }
      if (src <= std::numeric_limits<NativeDestT>::lowest()) {
        return std::numeric_limits<NativeDestT>::lowest();
      }
    }
    return static_cast<NativeDestT>(src);
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
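
// Illustrative results of the saturating float-to-int converter above (not
// part of the original source), for NativeDestT = int32:
//   NaN       -> 0
//   +infinity -> 2147483647  (INT32_MAX)
//   -1e20f    -> -2147483648 (INT32_MIN)
//   3.9f      -> 3           (truncated toward zero)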
1421
1422 template <typename NativeSrcT, typename NativeDestT>
1423 typename std::enable_if<!(std::is_floating_point<NativeSrcT>::value &&
1424 std::is_integral<NativeDestT>::value) &&
1425 !(std::is_same<NativeSrcT, Eigen::half>::value &&
1426 (std::is_same<NativeDestT, complex64>::value ||
1427 std::is_same<NativeDestT, complex128>::value)),
1428 Literal>::type
ConvertBetweenNativeTypes(const LiteralBase & src_literal)1429 ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
1430 auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
1431 return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1432 src_literal, converter);
1433 }
1434
1435 template <typename NativeSrcT, typename NativeDestT>
1436 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
1437 !std::is_same<NativeDestT, Eigen::half>::value),
1438 Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1439 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1440 auto converter = [](NativeSrcT src) {
1441 return absl::bit_cast<NativeDestT>(GetRawValue(src));
1442 };
1443 return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1444 src_literal, converter);
1445 }

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short first.
  auto converter = [](NativeSrcT src) {
    return Eigen::numext::bit_cast<Eigen::half>(
        absl::bit_cast<uint16>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}

// This overload exists only to make the compiler happy: absl::bit_cast has a
// static check that the types are the same size. It should never be reached,
// because the source and destination sizes are checked for equality higher
// up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}

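// Second level of the conversion dispatch: primitive_src_type is fixed at
// compile time, and the runtime destination type is matched against each
// supported PrimitiveType via the CONVERT_IF_TYPES_MATCH macro below.
// Complex destinations are only reachable for value conversions, not
// bitcasts, and unsupported pairs fall through to an Unimplemented error.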
template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}

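// First level of the conversion dispatch: matches the literal's runtime
// source element type against each supported PrimitiveType, then forwards to
// ConvertIfDestTypeMatches. Converting to the same element type returns a
// clone.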
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                 \
  case (type):                                                             \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
    // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}

}  // namespace

StatusOr<Literal> LiteralBase::Convert(
    PrimitiveType primitive_dest_type) const {
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}

StatusOr<Literal> LiteralBase::BitcastConvert(
    PrimitiveType primitive_dest_type) const {
  if (primitive_util::BitWidth(shape().element_type()) !=
      primitive_util::BitWidth(primitive_dest_type)) {
    return InvalidArgument(
        "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
        "%d",
        PrimitiveType_Name(shape().element_type()),
        PrimitiveType_Name(primitive_dest_type),
        primitive_util::BitWidth(shape().element_type()),
        primitive_util::BitWidth(primitive_dest_type));
  }
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
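
// For illustration, a hedged sketch of the difference between Convert and
// BitcastConvert (values are examples only; LiteralUtil is assumed from
// literal_util.h):
//
//   Literal one = LiteralUtil::CreateR0<float>(1.0f);
//   // Value-converting F32 -> S32 yields the integer 1.
//   StatusOr<Literal> converted = one.Convert(S32);
//   // Bitcast-converting F32 -> S32 reinterprets the bits of 1.0f,
//   // yielding 0x3F800000 == 1065353216.
//   StatusOr<Literal> bitcast = one.BitcastConvert(S32);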

StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
  if (!dest_shape.IsTuple()) {
    return Convert(dest_shape.element_type());
  }
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    auto element = LiteralSlice(*this, {i});
    TF_ASSIGN_OR_RETURN(
        auto new_element,
        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
    elements.push_back(std::move(new_element));
  }
  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}

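// Assembles a tuple literal from the given element literals without copying
// element data: the tuple is created with unallocated arrays and each
// element's buffers are moved in. The source literals are left in a
// moved-from state.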
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
    absl::Span<Literal> elements) {
  std::vector<Shape> element_shapes;
  for (const Literal& element : elements) {
    element_shapes.push_back(element.shape());
  }
  Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
                  /*allocate_arrays=*/false);
  for (int i = 0, end = elements.size(); i < end; ++i) {
    TF_CHECK_OK(
        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
  }
  return literal;
}

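// Copies elements between pieces whose shapes differ only in dynamic bounds.
// Indices are enumerated over the shape with the static bound, and elements
// beyond either piece's dynamic size are skipped.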
template <typename NativeT>
void LiteralBase::Piece::CopyElementsWithDynamicBound(
    const LiteralBase::Piece& src) {
  auto dest_shape = subshape();
  auto src_shape = src.subshape();

  // At least one of the two shapes must have a static bound.
  CHECK(dest_shape.is_static() || src_shape.is_static());
  auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64> index(dest_shape.rank());
  do {
    bool out_of_bound = false;
    for (int64_t i = 0; i < index.size(); ++i) {
      // Do not copy elements beyond the dynamic bound.
      if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
        out_of_bound = true;
      }
    }
    if (out_of_bound) {
      continue;
    }
    data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
                                                                  index)] =
        src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
            src_shape, index)];
  } while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
}

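// Recursively compares the elements of two pieces, respecting dynamic sizes:
// the walk descends one dimension per recursion level and only visits
// indices below this piece's dynamic size in each dimension.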
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64_t i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}

bool LiteralBase::Piece::EqualDynamicSize(
    const LiteralBase::Piece& other) const {
  DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
  if (subshape().is_static()) {
    return true;
  }

  for (int64_t i = 0; i < subshape().rank(); ++i) {
    if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
      return false;
    }
  }
  return true;
}

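// Element-wise equality of two pieces. Static dense arrays with equal shapes
// and layouts are compared with a single memcmp; everything else falls back
// to a typed element-by-element walk.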
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
  if (subshape().is_static() &&
      ShapeUtil::Equal(subshape(), other.subshape()) &&
      LayoutUtil::IsDenseArray(subshape())) {
    CHECK_EQ(size_bytes(), other.size_bytes());
    return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
  }

  std::vector<int64> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case S8:
      return EqualElementsInternal<int8>(other, &multi_index);
    case S16:
      return EqualElementsInternal<int16>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8>(other, &multi_index);
    case U16:
      return EqualElementsInternal<uint16>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    case C128:
      return EqualElementsInternal<complex128>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

bool LiteralBase::operator==(const LiteralBase& other) const {
  // Check that the tuple structures of the two literals match; element-wise
  // checks for arrays are performed below.
  if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
    return false;
  }

  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        const Piece& other_piece = other.piece(index);
        const Shape& subshape = piece.subshape();
        const Shape& other_subshape = other_piece.subshape();
        if (subshape.element_type() != other_subshape.element_type()) {
          return false;
        }
        if (!piece.subshape().IsArray()) {
          return true;
        }
        if (subshape.rank() != other_subshape.rank()) {
          return false;
        }

        for (int64_t i = 0; i < subshape.rank(); ++i) {
          if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
            return false;
          }
        }

        if (!piece.EqualElements(other_piece)) {
          return false;
        }
        return true;
      });
}

namespace {

template <typename NativeT>
static bool AllElementsEqualValue(absl::Span<const NativeT> data,
                                  NativeT value) {
  for (int64_t i = 0; i < data.size(); ++i) {
    if (data[i] != value) {
      return false;
    }
  }
  return true;
}

}  // namespace

bool LiteralBase::IsAll(int8_t value) const {
  return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
                                                  const Piece& piece) {
    if (!piece.subshape().IsArray()) {
      return true;
    }

    auto piece_is_all = [&]() {
      switch (shape().element_type()) {
        case U8:
          if (value >= 0) {
            return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
          }
          return false;
        case U16:
          if (value >= 0) {
            return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
          }
          return false;
        case U32:
          if (value >= 0) {
            return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
          }
          return false;
        case U64:
          if (value >= 0) {
            return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
          }
          return false;
        case S8:
          return AllElementsEqualValue<int8>(piece.data<int8>(), value);
        case S16:
          return AllElementsEqualValue<int16>(piece.data<int16>(), value);
        case S32:
          return AllElementsEqualValue<int32>(piece.data<int32>(), value);
        case S64:
          return AllElementsEqualValue<int64>(piece.data<int64>(), value);
        case F32:
          return AllElementsEqualValue<float>(piece.data<float>(), value);
        case F64:
          return AllElementsEqualValue<double>(piece.data<double>(), value);
        case F16:
          return AllElementsEqualValue<half>(piece.data<half>(),
                                             static_cast<half>(value));
        case BF16:
          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
                                                 static_cast<bfloat16>(value));
        case PRED:
          if (value == 0) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), false);
          }
          if (value == 1) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), true);
          }
          return false;
        default:
          return false;
      }
    };

    if (!piece_is_all()) {
      return false;
    }
    return true;
  });
}
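
// For illustration, a hedged sketch of IsAll semantics (example values only;
// LiteralUtil is assumed from literal_util.h):
//
//   Literal zeros = LiteralUtil::CreateR1<uint32>({0, 0, 0});
//   zeros.IsAll(0);   // true
//   zeros.IsAll(-1);  // false: a negative value can never match an
//                     // unsigned element type, per the cases above.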

bool LiteralBase::IsAllFloat(float value) const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        switch (shape().element_type()) {
          case F32:
            return AllElementsEqualValue<float>(piece.data<float>(), value);
          case F64:
            return AllElementsEqualValue<double>(piece.data<double>(), value);
          case F16:
            return AllElementsEqualValue<half>(piece.data<half>(),
                                               static_cast<half>(value));
          case BF16:
            return AllElementsEqualValue<bfloat16>(
                piece.data<bfloat16>(), static_cast<bfloat16>(value));
          default:
            return false;
        }
      });
}

bool LiteralBase::IsAllComplex(complex64 value) const {
  switch (shape().element_type()) {
    case C64:
      return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
                                              value);
    case C128:
      return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
                                               value);
    default:
      return false;
  }
}

bool LiteralBase::IsAllFirst() const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // Empty shapes are not all the first element since there is no first
        // element.
        if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
          return false;
        }
        auto piece_is_all = [&]() {
          switch (piece.subshape().element_type()) {
            case PRED: {
              auto data = piece.data<bool>();
              return AllElementsEqualValue<bool>(data, data[0]);
            }
            // 8 bit types
            case S8: {
              auto data = piece.data<int8>();
              return AllElementsEqualValue<int8>(data, data[0]);
            }
            case U8: {
              auto data = piece.data<uint8>();
              return AllElementsEqualValue<uint8>(data, data[0]);
            }
            // 16 bit types
            case BF16: {
              auto data = piece.data<bfloat16>();
              return AllElementsEqualValue<bfloat16>(data, data[0]);
            }
            case F16: {
              auto data = piece.data<half>();
              return AllElementsEqualValue<half>(data, data[0]);
            }
            case S16: {
              auto data = piece.data<int16>();
              return AllElementsEqualValue<int16>(data, data[0]);
            }
            case U16: {
              auto data = piece.data<uint16>();
              return AllElementsEqualValue<uint16>(data, data[0]);
            }
            // 32 bit types
            case F32: {
              auto data = piece.data<float>();
              return AllElementsEqualValue<float>(data, data[0]);
            }
            case U32: {
              auto data = piece.data<uint32>();
              return AllElementsEqualValue<uint32>(data, data[0]);
            }
            case S32: {
              auto data = piece.data<int32>();
              return AllElementsEqualValue<int32>(data, data[0]);
            }
            // 64 bit types
            case C64: {
              auto data = piece.data<complex64>();
              return AllElementsEqualValue<complex64>(data, data[0]);
            }
            case F64: {
              auto data = piece.data<double>();
              return AllElementsEqualValue<double>(data, data[0]);
            }
            case S64: {
              auto data = piece.data<int64>();
              return AllElementsEqualValue<int64>(data, data[0]);
            }
            case U64: {
              auto data = piece.data<uint64>();
              return AllElementsEqualValue<uint64>(data, data[0]);
            }

            case C128: {
              auto data = piece.data<complex128>();
              return AllElementsEqualValue<complex128>(data, data[0]);
            }
            default:
              return false;
          }
        };

        if (!piece_is_all()) {
          return false;
        }
        return true;
      });
}

bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  auto is_iota_at_idx = [&](const int64_t idx) {
    switch (shape().element_type()) {
      case U8:
        return static_cast<int64>(Get<uint8>({idx})) == idx;
      case U16:
        return static_cast<int64>(Get<uint16>({idx})) == idx;
      case U32:
        return static_cast<int64>(Get<uint32>({idx})) == idx;
      case U64:
        return static_cast<int64>(Get<uint64>({idx})) == idx;
      case S8:
        return Get<int8>({idx}) == idx;
      case S16:
        return Get<int16>({idx}) == idx;
      case S32:
        return Get<int32>({idx}) == idx;
      case S64:
        return Get<int64>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // pred, token, opaque, tuple, etc. are never iotas.
      default:
        return false;
    }
  };

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  for (int64_t idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}
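
// For illustration (example values only): an R1 S32 literal {0, 1, 2, 3}
// satisfies IsR1Iota(), while {0, 2, 4, 6} does not; the latter is instead
// recognized by IsR1StridedIota() below, with stride 2.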

// Returns the stride if the literal is a strided iota (i.e., an iota
// multiplied by a stride), and absl::nullopt otherwise. Only applicable to
// integer iotas.
absl::optional<int64> LiteralBase::IsR1StridedIota() const {
  if (!shape().IsArray() || shape().rank() != 1) {
    return absl::nullopt;
  }

  const int64_t elements = ShapeUtil::ElementsIn(shape());
  const PrimitiveType type = shape().element_type();
  if (elements <= 1 || !primitive_util::IsIntegralType(type)) {
    return absl::nullopt;
  }

  auto get_element_at = [&](const int64_t idx) -> int64 {
    switch (type) {
      case U8:
        return static_cast<int64>(Get<uint8>({idx}));
      case U16:
        return static_cast<int64>(Get<uint16>({idx}));
      case U32:
        return static_cast<int64>(Get<uint32>({idx}));
      case U64:
        return static_cast<int64>(Get<uint64>({idx}));
      case S8:
        return Get<int8>({idx});
      case S16:
        return Get<int16>({idx});
      case S32:
        return Get<int32>({idx});
      case S64:
        return Get<int64>({idx});
      default:
        CHECK(0);
        return 0;
    }
  };

  // Infer the stride from the second element, since the first element must
  // be zero (the loop below also verifies this).
  int64_t stride = get_element_at(1);
  if (stride == 0) {
    return absl::nullopt;
  }

  for (int64_t idx = 0; idx < elements; ++idx) {
    if (get_element_at(idx) != idx * stride) {
      return absl::nullopt;
    }
  }

  return stride;
}
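
// For illustration, a hedged usage sketch (names and values are examples
// only; LiteralUtil is assumed from literal_util.h):
//
//   Literal strided = LiteralUtil::CreateR1<int32>({0, 3, 6, 9});
//   strided.IsR1StridedIota();      // contains 3
//   Literal other = LiteralUtil::CreateR1<int32>({0, 1, 4, 9});
//   other.IsR1StridedIota();        // absl::nullopt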

bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8>(indices) == 0;
    case U16:
      return Get<uint16>(indices) == 0;
    case U32:
      return Get<uint32>(indices) == 0;
    case U64:
      return Get<uint64>(indices) == 0;
    case S8:
      return Get<int8>(indices) == 0;
    case S16:
      return Get<int16>(indices) == 0;
    case S32:
      return Get<int32>(indices) == 0;
    case S64:
      return Get<int64>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      // The CHECK above guarantees an array, so reaching here means the
      // element type is unsupported.
      LOG(FATAL) << "Unsupported element type: "
                 << PrimitiveType_Name(shape().element_type());
  }
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}

}  // namespace

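// Serializes this piece into the given proto. Most element types map to a
// repeated field; 8-bit and 16-bit types are packed into byte strings, and
// 16-bit payloads are byte-swapped to little-endian on big-endian hosts.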
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case U16:
      *proto->mutable_u16s() = string(
          reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape, which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

const void* LiteralBase::Piece::untyped_data() const {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

void* LiteralBase::Piece::untyped_data() {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
                             const RepeatedFieldT& src) {
  if (dest.size() != src.size()) {
    return InvalidArgument(
        "Expected %lu elements in LiteralProto repeated field, has %d",
        dest.size(), src.size());
  }
  std::copy(src.begin(), src.end(), dest.begin());
  return Status::OK();
}

}  // namespace

Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      auto s8_data = data<int8>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    case S16: {
      const string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case BF16: {
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64_t i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      const int64_t complex_data_size_doubled = complex_data.size() * 2;
      TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled);
      for (int64_t i = 0, end = complex_data.size(); i < end; ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return Status::OK();
}

LiteralProto LiteralBase::ToProto() const {
  LiteralProto proto;
  root_piece().ForEachSubpiece(
      [&](const ShapeIndex& index, const Piece& piece) {
        LiteralProto* proto_piece = &proto;
        for (int64_t i : index) {
          while (proto_piece->tuple_literals_size() <= i) {
            proto_piece->add_tuple_literals();
          }
          proto_piece = proto_piece->mutable_tuple_literals(i);
        }
        piece.WriteToProto(proto_piece);
      });

  return proto;
}
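
// For illustration, a hedged sketch of a proto round trip (example values
// only; CreateFromProto is assumed to be the static factory declared in
// literal.h):
//
//   Literal original = LiteralUtil::CreateR1<int32>({1, 2, 3});
//   LiteralProto proto = original.ToProto();
//   StatusOr<Literal> restored = Literal::CreateFromProto(proto);
//   // On success, *restored == original.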

const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
  return piece(shape_index).untyped_data();
}

void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
  return piece(shape_index).untyped_data();
}

int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
  return piece(shape_index).size_bytes();
}

string LiteralBase::GetR1U8AsString() const {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(shape().element_type(), U8);
  return string(absl::bit_cast<const char*>(data<uint8>().data()),
                ShapeUtil::ElementsIn(shape()));
}

void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
                                               Piece* src_piece,
                                               Piece* dest_piece) {
  DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
      << "src_piece has shape: "
      << ShapeUtil::HumanString(src_piece->subshape())
      << "; dest_piece has shape: "
      << ShapeUtil::HumanString(dest_piece->subshape());
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);

      dest_piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    dest_piece->set_buffer(src_piece->buffer());
    dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
  } else {
    // If the shape is neither an array nor a tuple, it must be zero-sized;
    // otherwise, some memory would need to be allocated for it.
    CHECK_EQ(dest_piece->size_bytes(), 0);
  }
}

MutableLiteralBase::~MutableLiteralBase() {}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}

MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  // Guard against self-assignment, and free the previously owned root piece
  // before replacing it so reassignment does not leak.
  if (this != &literal) {
    shape_ = absl::make_unique<Shape>(literal.shape());
    CHECK(LayoutUtil::HasLayout(*shape_));

    delete root_piece_;
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());

    CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
  }
  return *this;
}

MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal->shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    MutableBorrowingLiteral literal, const ShapeIndex& view_root)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  CHECK(!shape_->IsTuple());

  root_piece_ = new Piece();
  root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_->set_subshape(shape_.get());
}

MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  if (!shape_->IsTuple()) {
    CHECK_EQ(src_buf_ptrs.size(), 1);
    root_piece_ = new Piece();
    root_piece_->set_buffer(src_buf_ptrs[0]);
    root_piece_->set_subshape(shape_.get());
  } else {
    CHECK(!ShapeUtil::IsNestedTuple(*shape_));
    CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
    root_piece_ = new Piece();
    root_piece_->set_subshape(shape_.get());

    for (int i = 0; i < src_buf_ptrs.size(); ++i) {
      Piece child_piece;
      const auto& src_shape = shape_->tuple_shapes(i);
      CHECK(src_shape.IsArray());
      child_piece.set_subshape(&src_shape);
      child_piece.set_buffer(src_buf_ptrs[i]);
      root_piece_->emplace_back(std::move(child_piece));
    }
  }
}

MutableBorrowingLiteral::~MutableBorrowingLiteral() {
  if (root_piece_ != nullptr) {
    delete root_piece_;
  }
}

LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}

LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}

void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
  CHECK(shape.IsTuple());
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);

    auto child_piece = Piece();
    child_piece.set_subshape(&subshape);

    if (subshape.IsTuple()) {
      BuildPieceSubtree(subshape, &child_piece);
    }

    piece->emplace_back(std::move(child_piece));
  }
}

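// For illustration, a hedged sketch of wrapping an existing host buffer in a
// BorrowingLiteral without copying (names and values are examples only):
//
//   float values[4] = {1.0f, 2.0f, 3.0f, 4.0f};
//   Shape shape = ShapeUtil::MakeShape(F32, {4});
//   BorrowingLiteral view(reinterpret_cast<const char*>(values), shape);
//   // `view` aliases `values`; the caller must keep the buffer alive for
//   // as long as the literal is in use.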
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsArray());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = Piece();
  root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_.set_subshape(shape_.get());
}

BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                                   const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsTuple());
  CHECK(!ShapeUtil::IsNestedTuple(*shape_));
  CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  BuildPieceSubtree(*shape_, &root_piece_);

  for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
    const auto& src_shape = shape_->tuple_shapes(i);
    CHECK(src_shape.IsArray());
    root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
  }
}

}  // namespace xla