/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using absl::StrCat;

constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;

// Converts between little and big endian.
//
// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
void ConvertEndianShort(string* bytes) {
  CHECK_EQ(bytes->size() % 2, 0);
  for (int64 i = 0; i < bytes->size(); i += 2) {
    std::swap((*bytes)[i], (*bytes)[i + 1]);
  }
}

void ConvertEndianShort(char* bytes, int64 size) {
  CHECK_EQ(size % 2, 0);
  for (int64 i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}
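
// Usage sketch (illustrative only): swapping the bytes of two 16-bit values
// packed in a string.
//
//   string bytes("\x12\x34\x56\x78", 4);
//   ConvertEndianShort(&bytes);
//   // bytes now holds "\x34\x12\x78\x56".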

// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  return val;
}
uint16 GetRawValue(Eigen::half val) { return val.x; }

bool LiteralProtoHasValues(const LiteralProto& proto) {
  return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
         proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
         proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
         proto.c64s_size() || proto.c128s_size() ||
         proto.tuple_literals_size() || !proto.f16s().empty() ||
         !proto.bf16s().empty() || !proto.u16s().empty() ||
         !proto.s16s().empty();
}

}  // namespace

LiteralBase::~LiteralBase() {}

std::ostream& operator<<(std::ostream& out, const Literal& literal) {
  out << literal.ToString();
  return out;
}

MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}

Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}

void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    if (allocate_arrays) {
      // Literals can be used as DMA targets, which can require alignment. We
      // force a 16-byte minimum alignment.
      constexpr int kMinimumAlignment = 16;
      piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
          piece->size_bytes(), kMinimumAlignment)));
    }
  } else {
    // If the shape is neither an array nor tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}

Literal::Literal(const Shape& shape, bool allocate_arrays)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  CHECK(&root_piece_->subshape() == shape_.get());

  SetPiece(*shape_, root_piece_, allocate_arrays);
}

Literal::~Literal() {
  if (root_piece_ != nullptr) {
    DeallocateBuffers();
    delete root_piece_;
  }
}

void Literal::DeallocateBuffers() {
  root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->buffer());
        }
      });
}

Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}

Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_->subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  DCHECK(&root_piece_->subshape() == shape_.get());

  return *this;
}

Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}
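
// Usage sketch (illustrative; ShapeUtil::MakeShape is from shape_util.h,
// already included above):
//
//   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
//   Literal zeros = LiteralBase::CreateFromShape(shape);
//   // All six f32 elements are zero-initialized.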

template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64> src_base,
    absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
  TF_RET_CHECK(src_literal.shape().rank() == src_base.size());
  TF_RET_CHECK(shape().rank() == dest_base.size());

  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If either of the two shapes is a scalar, we can just call StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform the copy only if neither src nor dest has a zero-element
    // dimension; otherwise the copy is a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64 src_index = linear_index(src_literal.shape(), src_indexes);
      int64 dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to work around MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}

Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
                                           absl::Span<const int64> src_index,
                                           absl::Span<const int64> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
      src_literal.shape(), src_index);
  const int64 dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return Status::OK();
}
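
// Usage sketch (illustrative; assumes `src` and `dest` are R2 literals with
// the same element type):
//
//   TF_RETURN_IF_ERROR(dest.CopyElementFrom(src, /*src_index=*/{0, 1},
//                                           /*dest_index=*/{2, 3}));
//   // Copies the single element src[0, 1] into dest[2, 3].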

/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
    const LiteralProto& proto, bool prohibit_empty_literal) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  Shape shape(proto.shape());
  if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
    return InvalidArgument(
        "Literal shape cannot include OPAQUE_TYPE sub-shape");
  }
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("LiteralProto has no layout");
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  Literal literal(shape);

  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        const LiteralProto* proto_element = &proto;
        for (int64 i : index) {
          CHECK(i < proto_element->tuple_literals_size());
          proto_element = &proto_element->tuple_literals(i);
        }

        if (piece->subshape().IsTuple()) {
          if (proto_element->tuple_literals_size() !=
              ShapeUtil::TupleElementCount(piece->subshape())) {
            return InvalidArgument(
                "Expected %d tuple elements in LiteralProto, has %d",
                ShapeUtil::TupleElementCount(piece->subshape()),
                proto_element->tuple_literals_size());
          }
          return Status::OK();
        }
        if (piece->subshape().element_type() == TOKEN) {
          return Status::OK();
        }

        CHECK(piece->subshape().IsArray());

        // When prohibit_empty_literal is false (allowing literals with no
        // values), only copy from the proto if the literal proto has values.
        // This mode is used for a learned cost model.
        if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
          TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
        }

        return Status::OK();
      }));

  return std::move(literal);
}
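
// Usage sketch (illustrative round trip; LiteralBase::ToProto is declared in
// literal.h):
//
//   LiteralProto proto = literal.ToProto();
//   TF_ASSIGN_OR_RETURN(Literal copy,
//                       MutableLiteralBase::CreateFromProto(proto));
//   // `copy` now equals `literal`: same shape, layout, and values.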

std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          ShapeIndex src_index = {i};
          for (int64 j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->set_buffer(src_piece.buffer());
          src_piece.set_buffer(nullptr);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
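
// Usage sketch (illustrative; LiteralUtil::MakeTupleOwned and CreateR0 are
// from literal_util.h, which this file does not include):
//
//   Literal tuple = LiteralUtil::MakeTupleOwned(
//       LiteralUtil::CreateR0<float>(1.0f), LiteralUtil::CreateR0<int32>(2));
//   std::vector<Literal> parts = tuple.DecomposeTuple();
//   // parts[0] is f32[] 1, parts[1] is s32[] 2, and `tuple` is now nil.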

namespace {

// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  std::vector<int64> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}

}  // namespace

Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
    std::vector<int64> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                    \
  case (XLA_T):                                                           \
    CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                  subshape(), src.subshape());            \
    break;
      COPY_ELEMENTS(U8, uint8);
      COPY_ELEMENTS(U16, uint16);
      COPY_ELEMENTS(U32, uint32);
      COPY_ELEMENTS(U64, uint64);
      COPY_ELEMENTS(S8, int8);
      COPY_ELEMENTS(S16, int16);
      COPY_ELEMENTS(S32, int32);
      COPY_ELEMENTS(S64, int64);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  return Status::OK();
}

Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
    return InvalidArgument(
        "Destination subshape incompatible with source subshape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_subshape));
  }
  return root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return Status::OK();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return Status::OK();
        }
        // Construct the index of the corresponding piece in the source
        // literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
        return Status::OK();
      });
}

Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_->ForEachSubpiece(
      [&](const ShapeIndex& src_index, const Piece& src_piece) {
        if (!src_piece.subshape().IsArray()) {
          return;
        }

        ShapeIndex dest_index = dest_shape_index;
        for (int64 i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        tensorflow::port::AlignedFree(dest_piece.buffer());
        dest_piece.set_buffer(src_piece.buffer());
      });

  src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
  delete src_literal.root_piece_;
  src_literal.root_piece_ = new LiteralBase::Piece();
  src_literal.root_piece_->set_subshape(src_literal.shape_.get());

  return Status::OK();
}
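
// Usage sketch (illustrative; assumes `dest` is a tuple literal whose element
// {0} has exactly the same shape as `src`):
//
//   TF_RETURN_IF_ERROR(dest.MoveFrom(std::move(src),
//                                    /*dest_shape_index=*/{0}));
//   // The buffers of `src` now belong to dest's element {0}; `src` is nil.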

Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64> src_base,
                                         absl::Span<const int64> dest_base,
                                         absl::Span<const int64> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
                                          copy_size);
    case U16:
      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
                                           copy_size);
    case U32:
      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
                                           copy_size);
    case U64:
      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
                                           copy_size);
    case S8:
      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
                                         copy_size);
    case S16:
      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
                                          copy_size);
    case S32:
      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
                                          copy_size);
    case S64:
      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
                                          copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base,
                                               dest_base, copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}
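
// Usage sketch (illustrative; copies a 2x2 block between two R2 literals of
// the same element type, so `src` must be at least 3x3 here):
//
//   TF_RETURN_IF_ERROR(dest.CopySliceFrom(src, /*src_base=*/{1, 1},
//                                         /*dest_base=*/{0, 0},
//                                         /*copy_size=*/{2, 2}));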

void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(element_count(), values.bits());
  CHECK_EQ(shape().element_type(), PRED);
  for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
    Set({i}, values.get(i));
  }
}

Literal LiteralBase::Relayout(const Layout& new_layout,
                              const ShapeIndex& shape_index) const {
  // Create new shape with 'new_layout' set at the given shape index.
  Shape new_shape = shape();
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
  *subshape->mutable_layout() = new_layout;
  Literal result(new_shape);
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}
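
// Usage sketch (illustrative; converts a row-major R2 literal to column-major
// without changing its logical values):
//
//   Literal col_major = literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
//   // col_major prints the same values; only the physical ordering of the
//   // backing buffer changed.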

Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}

StatusOr<Literal> LiteralBase::Broadcast(
    const Shape& result_shape, absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Broadcast only supports arrays.");
  }

  for (int64 i = 0; i < dimensions.size(); i++) {
    TF_RET_CHECK(shape().dimensions(i) ==
                 result_shape.dimensions(dimensions[i]));
  }

  Literal result(result_shape);

  // scratch_source_index is temporary storage space for the computed index
  // into the input literal. We put it here to avoid allocating an std::vector
  // in every iteration of ShapeUtil::ForEachIndex.
  std::vector<int64> scratch_source_index(shape().dimensions_size());

  char* dest_data = static_cast<char*>(result.untyped_data());
  const char* source_data = static_cast<const char*>(untyped_data());
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  ShapeUtil::ForEachIndex(
      result_shape, [&](absl::Span<const int64> output_index) {
        for (int64 i = 0; i < dimensions.size(); ++i) {
          scratch_source_index[i] = output_index[dimensions[i]];
        }
        int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            result_shape, output_index);
        int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
            shape(), scratch_source_index);
        memcpy(dest_data + primitive_size * dest_index,
               source_data + primitive_size * source_index, primitive_size);
        return true;
      });

  return std::move(result);
}
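
// Usage sketch (illustrative; LiteralUtil::CreateR1 is from literal_util.h,
// not included here). `dimensions[i]` names the result dimension that input
// dimension i maps to:
//
//   Literal row = LiteralUtil::CreateR1<float>({1, 2, 3});  // f32[3]
//   TF_ASSIGN_OR_RETURN(
//       Literal bcast,
//       row.Broadcast(ShapeUtil::MakeShape(F32, {2, 3}), /*dimensions=*/{1}));
//   // bcast is f32[2,3] {{1, 2, 3}, {1, 2, 3}}.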

StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  Literal output;
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64 elements_before = ShapeUtil::ElementsIn(shape());
  int64 elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
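
// Usage sketch (illustrative; the element count must be preserved):
//
//   Literal m = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
//   TF_ASSIGN_OR_RETURN(Literal r, m.Reshape({3, 2}));
//   // r is f32[3,2] {{1, 2}, {3, 4}, {5, 6}}: the same row-major value
//   // sequence reinterpreted under the new dimensions.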

Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(IsPermutation(permutation, shape().rank()))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  Shape permuted_shape =
      ShapeUtil::PermuteDimensions(inverse_permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the
  // one of size 8) as the most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}
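
// Usage sketch (illustrative; swaps the two dimensions of an R2 literal):
//
//   Literal m = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
//   Literal t = m.Transpose({1, 0});
//   // t is f32[3,2] {{1, 4}, {2, 5}, {3, 6}}; no per-element copy occurred,
//   // only the shape and layout were permuted over the same flat buffer.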

template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  CHECK(result_literal
            .Populate<NativeT>([&](absl::Span<const int64> indices) {
              for (int64 i = 0; i < result_shape.rank(); ++i) {
                new_indices[i] = indices[i] + start_indices[i];
              }
              return Get<NativeT>(new_indices);
            })
            .ok());
  return result_literal;
}

Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
                           absl::Span<const int64> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64 dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  const auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}
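
// Usage sketch (illustrative; start indices are inclusive, limit indices
// exclusive):
//
//   Literal m = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
//   Literal s = m.Slice(/*start_indices=*/{0, 1}, /*limit_indices=*/{2, 3});
//   // s is f32[2,2] {{2, 3}, {5, 6}}.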

Literal LiteralBase::Clone() const {
  Literal result(shape());
  TF_CHECK_OK(result.CopyFrom(*this));
  return result;
}

string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return RoundTripFpToString(Get<half>(multi_index, shape_index));
    case F32:
      return RoundTripFpToString(Get<float>(multi_index, shape_index));
    case BF16:
      return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
    case F64:
      return RoundTripFpToString(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}

absl::optional<int64> LiteralBase::GetIntegralAsS64(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case U8:
      return Get<uint8>(multi_index);
    case S32:
      return Get<int32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    default:
      return absl::nullopt;
  }
}

absl::optional<double> LiteralBase::GetAsDouble(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      return static_cast<double>(Get<half>(multi_index));
    case F32:
      return static_cast<double>(Get<float>(multi_index));
    case F64:
      return Get<double>(multi_index);
    case BF16:
      return static_cast<double>(Get<bfloat16>(multi_index));
    default:
      return absl::nullopt;
  }
}

absl::optional<complex128> LiteralBase::GetAsComplex128(
    absl::Span<const int64> multi_index) const {
  switch (shape().element_type()) {
    case BF16:
      return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
    case F16:
      return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
    case F32:
      return {{Get<float>(multi_index), 0}};
    case F64:
      return {{Get<double>(multi_index), 0}};
    case C64:
      return {Get<complex64>(multi_index)};
    case C128:
      return {Get<complex128>(multi_index)};
    case S8:
      return {Get<int8>(multi_index)};
    default:
      return absl::nullopt;
  }
}

size_t LiteralBase::Hash() const {
  using tensorflow::Hash64;
  using tensorflow::Hash64Combine;

  size_t hash_value = ShapeUtil::Hash(shape());

  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }

        CHECK(LayoutUtil::IsDense(subshape.layout()));
        hash_value = Hash64Combine(
            hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
                               size_bytes(index)));
      });

  return hash_value;
}

Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
                                            int64 value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      Set<bool>(multi_index, value);
      break;
    case U8:
      Set<uint8>(multi_index, value);
      break;
    case S32:
      Set<int32>(multi_index, value);
      break;
    case S64:
      Set<int64>(multi_index, value);
      break;
    case U32:
      Set<uint32>(multi_index, value);
      break;
    case U64:
      Set<uint64>(multi_index, value);
      break;
    default:
      return FailedPrecondition("Array element type is not integral: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}

Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
                                         double value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      Set<half>(multi_index, Eigen::half(value));
      break;
    case F32:
      Set<float>(multi_index, value);
      break;
    case F64:
      Set<double>(multi_index, value);
      break;
    case BF16:
      Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
      break;
    default:
      return FailedPrecondition("Array element type is not floating: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}

namespace {

string ShapeToString(bool print_layout, const Shape& shape) {
  return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
                      : ShapeUtil::HumanString(shape);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces);

void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<string> tuple_pieces;
  for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}

void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64 rank = subshape.rank();

  std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
      to_string_recursive = [&](absl::Span<const int64> dimensions,
                                std::vector<int64>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        auto brace_to_string = [&](string brace) -> string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices->size() < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            std::vector<int64> cloned_indices(*accum_indices);
            cloned_indices.push_back(i);
            to_string_recursive(dimensions.subspan(1), &cloned_indices);
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    pieces->push_back(" ");
  }
  std::vector<int64> indices = {};
  std::vector<int64> dimensions(subshape.dimensions().begin(),
                                subshape.dimensions().end());
  to_string_recursive(dimensions, &indices);
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                             pieces);
  }
}

}  // namespace

string LiteralBase::ToString() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}
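
// Usage sketch (illustrative; the exact formatting is produced by the
// helpers above):
//
//   Literal m = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}});
//   string s = m.ToString();
//   // s looks roughly like: "s32[2,2] {\n { 1, 2 },\n { 3, 4 }\n}".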

string LiteralBase::ToStringWithoutShape() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/false,
                 /*print_layout=*/false, &pieces);
  return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithLayout() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/true, &pieces);
  return absl::StrJoin(pieces, "");
}

void LiteralBase::EachCellAsString(
    const std::function<void(absl::Span<const int64> indices,
                             const string& value)>& per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
      shape(), /*linear_index=*/0);
  do {
    per_cell(indices, GetAsString(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}

namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
                                               const ConverterType& converter) {
  CHECK(src_literal.shape().IsArray());
  Literal result_literal(ShapeUtil::ChangeElementType(
      src_literal.shape(),
      primitive_util::NativeToPrimitiveType<NativeDestT>()));
  auto src_data = src_literal.data<NativeSrcT>();
  auto dest_data = result_literal.template data<NativeDestT>();
  int64 num_elements = src_literal.element_count();

  for (int64 i = 0; i < num_elements; ++i) {
    dest_data[i] = converter(src_data[i]);
  }
  return result_literal;
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
                            (!std::is_same<NativeDestT, complex64>::value &&
                             !std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
                         !std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return absl::bit_cast<NativeDestT>(GetRawValue(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short and then use raw_uint16_to_half.
  auto converter = [](NativeSrcT src) {
    return Eigen::half_impl::raw_uint16_to_half(
        absl::bit_cast<uint16>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}

// This template specialization is here to make the compiler happy. bit_cast
// has a static check that the types are the same size. This specialization
// should never be used because the source and destination types are checked
// for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}

template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}

StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
      // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}

}  // namespace

StatusOr<Literal> LiteralBase::Convert(
    PrimitiveType primitive_dest_type) const {
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}

StatusOr<Literal> LiteralBase::BitcastConvert(
    PrimitiveType primitive_dest_type) const {
  if (primitive_util::BitWidth(shape().element_type()) !=
      primitive_util::BitWidth(primitive_dest_type)) {
    return InvalidArgument(
        "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
        "%d",
        PrimitiveType_Name(shape().element_type()),
        PrimitiveType_Name(primitive_dest_type),
        primitive_util::BitWidth(shape().element_type()),
        primitive_util::BitWidth(primitive_dest_type));
  }
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
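
// Usage sketch (illustrative; Convert changes values, BitcastConvert
// reinterprets bits and requires equal bit widths):
//
//   Literal f = LiteralUtil::CreateR0<float>(1.0f);
//   TF_ASSIGN_OR_RETURN(Literal d, f.Convert(F64));          // f64[] 1.0
//   TF_ASSIGN_OR_RETURN(Literal i, f.BitcastConvert(S32));
//   // i is s32[] 0x3f800000, the bit pattern of 1.0f.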

StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
  if (!dest_shape.IsTuple()) {
    return Convert(dest_shape.element_type());
  }
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    auto element = LiteralSlice(*this, {i});
    TF_ASSIGN_OR_RETURN(
        auto new_element,
        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
    elements.push_back(std::move(new_element));
  }
  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}

/* static */ Literal MutableLiteralBase::MoveIntoTuple(
    absl::Span<Literal> elements) {
  std::vector<Shape> element_shapes;
  for (const Literal& element : elements) {
    element_shapes.push_back(element.shape());
  }
  Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
                  /*allocate_arrays=*/false);
  for (int i = 0; i < elements.size(); ++i) {
    TF_CHECK_OK(
        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
  }
  return literal;
}
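
// Usage sketch (illustrative; consumes the element literals):
//
//   std::vector<Literal> parts;
//   parts.push_back(LiteralUtil::CreateR0<float>(1.0f));
//   parts.push_back(LiteralUtil::CreateR0<int32>(2));
//   Literal tuple = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(parts));
//   // tuple is (f32[], s32[]); the originals in `parts` are now nil.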
1346
1347 template <typename NativeT>
EqualElementsInternal(const LiteralBase::Piece & other,std::vector<int64> * multi_index) const1348 bool LiteralBase::Piece::EqualElementsInternal(
1349 const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
1350 if (multi_index->size() == subshape().rank()) {
1351 return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
1352 }
1353 for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
1354 multi_index->push_back(i);
1355 if (!EqualElementsInternal<NativeT>(other, multi_index)) {
1356 return false;
1357 }
1358 multi_index->pop_back();
1359 }
1360 return true;
1361 }
1362
EqualElements(const LiteralBase::Piece & other) const1363 bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
1364 DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
1365
1366 if (ShapeUtil::Equal(subshape(), other.subshape()) &&
1367 LayoutUtil::IsDenseArray(subshape())) {
1368 CHECK_EQ(size_bytes(), other.size_bytes());
1369 return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
1370 }
1371
1372 std::vector<int64> multi_index;
1373 switch (subshape().element_type()) {
1374 case PRED:
1375 return EqualElementsInternal<bool>(other, &multi_index);
1376 case U8:
1377 return EqualElementsInternal<uint8>(other, &multi_index);
1378 case S16:
1379 return EqualElementsInternal<int16>(other, &multi_index);
1380 case S32:
1381 return EqualElementsInternal<int32>(other, &multi_index);
1382 case S64:
1383 return EqualElementsInternal<int64>(other, &multi_index);
1384 case U16:
1385 return EqualElementsInternal<uint16>(other, &multi_index);
1386 case U32:
1387 return EqualElementsInternal<uint32>(other, &multi_index);
1388 case U64:
1389 return EqualElementsInternal<uint64>(other, &multi_index);
1390 case F32:
1391 return EqualElementsInternal<float>(other, &multi_index);
1392 case F64:
1393 return EqualElementsInternal<double>(other, &multi_index);
1394 case F16:
1395 return EqualElementsInternal<half>(other, &multi_index);
1396 case BF16:
1397 return EqualElementsInternal<bfloat16>(other, &multi_index);
1398 case C64:
1399 return EqualElementsInternal<complex64>(other, &multi_index);
1400 case C128:
1401 return EqualElementsInternal<complex128>(other, &multi_index);
1402 default:
1403 LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
1404 << PrimitiveType_Name(subshape().element_type());
1405 }
1406 }
1407
operator ==(const LiteralBase & other) const1408 bool LiteralBase::operator==(const LiteralBase& other) const {
1409 if (!ShapeUtil::Compatible(shape(), other.shape())) {
1410 return false;
1411 }
1412
1413 return root_piece().ForEachSubpieceWithBool(
1414 [&](const ShapeIndex& index, const Piece& piece) {
1415 if (!piece.subshape().IsArray()) {
1416 return true;
1417 }
1418
1419 const Piece& other_piece = other.piece(index);
1420 if (!piece.EqualElements(other_piece)) {
1421 return false;
1422 }
1423 return true;
1424 });
1425 }
1426
1427 namespace {
1428
1429 template <typename NativeT>
AllElementsEqualValue(absl::Span<const NativeT> data,NativeT value)1430 static bool AllElementsEqualValue(absl::Span<const NativeT> data,
1431 NativeT value) {
1432 for (int64 i = 0; i < data.size(); ++i) {
1433 if (data[i] != value) {
1434 return false;
1435 }
1436 }
1437 return true;
1438 }
1439
1440 } // namespace

bool LiteralBase::IsAll(int8 value) const {
  return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
                                                  const Piece& piece) {
    if (!piece.subshape().IsArray()) {
      return true;
    }

    auto piece_is_all = [&]() {
      // Use the piece's own element type so that tuple elements are checked
      // against their actual type rather than the root shape's.
      switch (piece.subshape().element_type()) {
        case U8:
          if (value >= 0) {
            return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
          }
          return false;
        case U16:
          if (value >= 0) {
            return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
          }
          return false;
        case U32:
          if (value >= 0) {
            return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
          }
          return false;
        case U64:
          if (value >= 0) {
            return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
          }
          return false;
        case S8:
          return AllElementsEqualValue<int8>(piece.data<int8>(), value);
        case S16:
          return AllElementsEqualValue<int16>(piece.data<int16>(), value);
        case S32:
          return AllElementsEqualValue<int32>(piece.data<int32>(), value);
        case S64:
          return AllElementsEqualValue<int64>(piece.data<int64>(), value);
        case F32:
          return AllElementsEqualValue<float>(piece.data<float>(), value);
        case F64:
          return AllElementsEqualValue<double>(piece.data<double>(), value);
        case F16:
          return AllElementsEqualValue<half>(piece.data<half>(),
                                             static_cast<half>(value));
        case BF16:
          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
                                                 static_cast<bfloat16>(value));
        case PRED:
          if (value == 0) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), false);
          }
          if (value == 1) {
            return AllElementsEqualValue<bool>(piece.data<bool>(), true);
          }
          return false;
        default:
          return false;
      }
    };

    if (!piece_is_all()) {
      return false;
    }
    return true;
  });
}
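
// Example (illustrative sketch; assumes LiteralUtil::CreateR1 from
// literal_util.h): the query value is cast to the literal's element type, and
// negative values can never match an unsigned literal.
//
//   CHECK(LiteralUtil::CreateR1<int32>({8, 8, 8}).IsAll(8));
//   CHECK(!LiteralUtil::CreateR1<int32>({8, 7, 8}).IsAll(8));
//   CHECK(!LiteralUtil::CreateR1<uint32>({1, 1}).IsAll(-1));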

bool LiteralBase::IsAllFloat(float value) const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // As in IsAll, check against the piece's own element type.
        switch (piece.subshape().element_type()) {
          case F32:
            return AllElementsEqualValue<float>(piece.data<float>(), value);
          case F64:
            return AllElementsEqualValue<double>(piece.data<double>(), value);
          case F16:
            return AllElementsEqualValue<half>(piece.data<half>(),
                                               static_cast<half>(value));
          case BF16:
            return AllElementsEqualValue<bfloat16>(
                piece.data<bfloat16>(), static_cast<bfloat16>(value));
          default:
            return false;
        }
      });
}
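
// Example (illustrative sketch): only floating-point element types can
// match, and the comparison is exact.
//
//   CHECK(LiteralUtil::CreateR1<float>({0.5f, 0.5f}).IsAllFloat(0.5f));
//   CHECK(!LiteralUtil::CreateR1<int32>({0, 0}).IsAllFloat(0.f));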

bool LiteralBase::IsAllComplex(complex64 value) const {
  switch (shape().element_type()) {
    case C64:
      return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
                                              value);
    case C128:
      return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
                                               value);
    default:
      return false;
  }
}
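
// Example (illustrative sketch; complex64 is std::complex<float>):
//
//   Literal c = LiteralUtil::CreateR1<complex64>({{1.f, 2.f}, {1.f, 2.f}});
//   CHECK(c.IsAllComplex(complex64(1.f, 2.f)));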

bool LiteralBase::IsAllFirst() const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // Empty shapes are not all the first element since there is no first
        // element.
        if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
          return false;
        }
        auto piece_is_all = [&]() {
          switch (piece.subshape().element_type()) {
            case PRED: {
              auto data = piece.data<bool>();
              return AllElementsEqualValue<bool>(data, data[0]);
            }
            // 8 bit types
            case S8: {
              auto data = piece.data<int8>();
              return AllElementsEqualValue<int8>(data, data[0]);
            }
            case U8: {
              auto data = piece.data<uint8>();
              return AllElementsEqualValue<uint8>(data, data[0]);
            }
            // 16 bit types
            case BF16: {
              auto data = piece.data<bfloat16>();
              return AllElementsEqualValue<bfloat16>(data, data[0]);
            }
            case F16: {
              auto data = piece.data<half>();
              return AllElementsEqualValue<half>(data, data[0]);
            }
            case S16: {
              auto data = piece.data<int16>();
              return AllElementsEqualValue<int16>(data, data[0]);
            }
            case U16: {
              auto data = piece.data<uint16>();
              return AllElementsEqualValue<uint16>(data, data[0]);
            }
            // 32 bit types
            case F32: {
              auto data = piece.data<float>();
              return AllElementsEqualValue<float>(data, data[0]);
            }
            case U32: {
              auto data = piece.data<uint32>();
              return AllElementsEqualValue<uint32>(data, data[0]);
            }
            case S32: {
              auto data = piece.data<int32>();
              return AllElementsEqualValue<int32>(data, data[0]);
            }
            // 64 bit types
            case C64: {
              auto data = piece.data<complex64>();
              return AllElementsEqualValue<complex64>(data, data[0]);
            }
            case F64: {
              auto data = piece.data<double>();
              return AllElementsEqualValue<double>(data, data[0]);
            }
            case S64: {
              auto data = piece.data<int64>();
              return AllElementsEqualValue<int64>(data, data[0]);
            }
            case U64: {
              auto data = piece.data<uint64>();
              return AllElementsEqualValue<uint64>(data, data[0]);
            }
            // 128 bit types
            case C128: {
              auto data = piece.data<complex128>();
              return AllElementsEqualValue<complex128>(data, data[0]);
            }
            default:
              return false;
          }
        };

        if (!piece_is_all()) {
          return false;
        }
        return true;
      });
}
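
// Example (illustrative sketch):
//
//   CHECK(LiteralUtil::CreateR1<int32>({5, 5, 5}).IsAllFirst());
//   CHECK(!LiteralUtil::CreateR1<int32>({5, 6, 5}).IsAllFirst());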

bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  auto is_iota_at_idx = [&](const int64 idx) {
    switch (shape().element_type()) {
      case U8:
        return Get<uint8>({idx}) == idx;
      case U16:
        return Get<uint16>({idx}) == idx;
      case U32:
        return Get<uint32>({idx}) == idx;
      case U64:
        return Get<uint64>({idx}) == idx;
      case S8:
        return Get<int8>({idx}) == idx;
      case S16:
        return Get<int16>({idx}) == idx;
      case S32:
        return Get<int32>({idx}) == idx;
      case S64:
        return Get<int64>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // pred, token, opaque, tuple, etc. are never iotas.
      default:
        return false;
    }
  };

  const int64 elements = ShapeUtil::ElementsIn(shape());
  for (int64 idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}
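
// Example (illustrative sketch): an R1 literal is an iota iff element i
// compares equal to i at every index.
//
//   CHECK(LiteralUtil::CreateR1<int32>({0, 1, 2, 3}).IsR1Iota());
//   CHECK(!LiteralUtil::CreateR1<int32>({1, 2, 3}).IsR1Iota());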

bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8>(indices) == 0;
    case U16:
      return Get<uint16>(indices) == 0;
    case U32:
      return Get<uint32>(indices) == 0;
    case U64:
      return Get<uint64>(indices) == 0;
    case S8:
      return Get<int8>(indices) == 0;
    case S16:
      return Get<int16>(indices) == 0;
    case S32:
      return Get<int32>(indices) == 0;
    case S64:
      return Get<int64>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "Unsupported element type for IsZero: "
                 << PrimitiveType_Name(shape().element_type());
  }
}
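
// Example (illustrative sketch): IsZero tests one element, addressed by its
// multi-dimensional index.
//
//   Literal m = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 0}});
//   CHECK(m.IsZero({0, 0}));
//   CHECK(!m.IsZero({0, 1}));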

namespace {

template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
                         const absl::Span<const NativeT> src) {
  *dest = RepeatedFieldT(src.begin(), src.end());
}

}  // namespace

void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    // 16-bit types are serialized as raw bytes, normalized to little-endian
    // order so the proto is portable across hosts.
    case U16:
      *proto->mutable_u16s() = string(
          reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      // Complex values are flattened as interleaved (real, imag) pairs.
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape, which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}

const void* LiteralBase::Piece::untyped_data() const {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

void* LiteralBase::Piece::untyped_data() {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}

namespace {

template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
                             const RepeatedFieldT& src) {
  if (dest.size() != src.size()) {
    return InvalidArgument(
        "Expected %lu elements in LiteralProto repeated field, has %d",
        dest.size(), src.size());
  }
  std::copy(src.begin(), src.end(), dest.begin());
  return Status::OK();
}

}  // namespace

Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      auto s8_data = data<int8>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    // 16-bit types arrive as little-endian byte strings; convert back to the
    // host byte order after the bulk copy if necessary.
    case S16: {
      const string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case BF16: {
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return Status::OK();
}

LiteralProto LiteralBase::ToProto() const {
  LiteralProto proto;
  root_piece().ForEachSubpiece(
      [&](const ShapeIndex& index, const Piece& piece) {
        LiteralProto* proto_piece = &proto;
        for (int64 i : index) {
          while (proto_piece->tuple_literals_size() <= i) {
            proto_piece->add_tuple_literals();
          }
          proto_piece = proto_piece->mutable_tuple_literals(i);
        }
        piece.WriteToProto(proto_piece);
      });

  return proto;
}
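
// Example (illustrative round-trip sketch; error handling abbreviated):
//
//   LiteralProto proto = literal.ToProto();
//   StatusOr<Literal> restored = MutableLiteralBase::CreateFromProto(proto);
//   if (restored.ok()) {
//     CHECK(restored.ValueOrDie() == literal);
//   }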

const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
  return piece(shape_index).untyped_data();
}

void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
  return piece(shape_index).untyped_data();
}

int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
  return piece(shape_index).size_bytes();
}

string LiteralBase::GetR1U8AsString() const {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(shape().element_type(), U8);
  return string(absl::bit_cast<const char*>(data<uint8>().data()),
                ShapeUtil::ElementsIn(shape()));
}
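
// Example (illustrative sketch; assumes LiteralUtil::CreateR1U8 from
// literal_util.h):
//
//   Literal s = LiteralUtil::CreateR1U8("hello");
//   CHECK_EQ(s.GetR1U8AsString(), "hello");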

void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
                                               Piece* src_piece,
                                               Piece* dest_piece) {
  DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
      << "src_piece has shape: "
      << ShapeUtil::HumanString(src_piece->subshape())
      << " dest_piece has shape: "
      << ShapeUtil::HumanString(dest_piece->subshape());
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);

      dest_piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    dest_piece->set_buffer(src_piece->buffer());
  } else {
    // If the shape is neither an array nor a tuple, it must be zero-sized
    // (e.g. a token), so there is no buffer to share.
    CHECK_EQ(dest_piece->size_bytes(), 0);
  }
}

MutableLiteralBase::~MutableLiteralBase() {}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}

MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  if (&literal == this) {
    return *this;
  }
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  // Release the previously owned piece tree before building the new one so
  // that assignment does not leak it.
  delete root_piece_;
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);

  return *this;
}

MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal->shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(
    MutableBorrowingLiteral literal, const ShapeIndex& view_root)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}

MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
                                                 const Shape& shape)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  CHECK(!shape_->IsTuple());

  root_piece_ = new Piece();
  root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_->set_subshape(shape_.get());
}
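
// Example (illustrative sketch; shape helper from shape_util.h): a
// MutableBorrowingLiteral wraps caller-owned storage without copying, so
// writes through the literal mutate the underlying buffer.
//
//   float storage[4] = {0.f, 0.f, 0.f, 0.f};
//   Shape s = ShapeUtil::MakeShapeWithDescendingLayout(F32, {4});
//   MutableBorrowingLiteral view(reinterpret_cast<const char*>(storage), s);
//   view.Set<float>({0}, 42.f);  // storage[0] is now 42.f.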

MutableBorrowingLiteral::~MutableBorrowingLiteral() {
  if (root_piece_ != nullptr) {
    delete root_piece_;
  }
}

LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}

LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}

void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
  CHECK(shape.IsTuple());
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);

    auto child_piece = Piece();
    child_piece.set_subshape(&subshape);

    if (subshape.IsTuple()) {
      BuildPieceSubtree(subshape, &child_piece);
    }

    piece->emplace_back(std::move(child_piece));
  }
}

BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsArray());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = Piece();
  root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
  root_piece_.set_subshape(shape_.get());
}
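
// Example (illustrative sketch): a BorrowingLiteral is a read-only,
// non-owning view, so the buffer must outlive it.
//
//   std::vector<float> buf = {1.f, 2.f, 3.f, 4.f};
//   BorrowingLiteral view(reinterpret_cast<const char*>(buf.data()),
//                         ShapeUtil::MakeShapeWithDescendingLayout(F32, {4}));
//   CHECK_EQ(view.Get<float>({2}), 3.f);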

BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                                   const Shape& shape)
    : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
  CHECK(shape_->IsTuple());
  CHECK(!ShapeUtil::IsNestedTuple(*shape_));
  CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
  root_piece_ = Piece();
  root_piece_.set_subshape(shape_.get());
  BuildPieceSubtree(*shape_, &root_piece_);

  for (int i = 0; i < src_buf_ptrs.size(); ++i) {
    const auto& src_shape = shape_->tuple_shapes(i);
    CHECK(src_shape.IsArray());
    root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
  }
}
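
// Example (illustrative sketch; f32_buf and s32_buf are hypothetical element
// buffers owned by the caller): the tuple overload takes one buffer per
// element of a non-nested tuple.
//
//   const char* bufs[] = {f32_buf, s32_buf};
//   Shape tuple_shape = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShapeWithDescendingLayout(F32, {4}),
//        ShapeUtil::MakeShapeWithDescendingLayout(S32, {2})});
//   BorrowingLiteral tuple_view(bufs, tuple_shape);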

}  // namespace xla