1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implementation notes:
17 //
18 // Tensor.cc uses a few templated classes and structs to facilitate
19 // implementation of the Tensor class.
20 //
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
25 //
// * Helper<T>: provides various routines given type T. The routines
//   include running the constructor and destructor of T[], encoding
//   and decoding T[] into/from a Cord, etc.
29
30 #include "tensorflow/core/framework/tensor.h"
31
32 #include <utility>
33
34 #include "absl/strings/escaping.h"
35 #include "tensorflow/core/framework/allocation_description.pb.h"
36 #include "tensorflow/core/framework/log_memory.h"
37 #include "tensorflow/core/framework/resource_handle.h"
38 #include "tensorflow/core/framework/resource_handle.pb.h"
39 #include "tensorflow/core/framework/tensor.pb.h"
40 #include "tensorflow/core/framework/tensor_description.pb.h"
41 #include "tensorflow/core/framework/type_traits.h"
42 #include "tensorflow/core/framework/typed_allocator.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/framework/types.pb.h"
45 #include "tensorflow/core/framework/variant.h"
46 #include "tensorflow/core/framework/variant_encode_decode.h"
47 #include "tensorflow/core/framework/variant_op_registry.h"
48 #include "tensorflow/core/framework/variant_tensor_data.h"
49 #include "tensorflow/core/lib/core/coding.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/status.h"
52 #include "tensorflow/core/lib/gtl/inlined_vector.h"
53 #include "tensorflow/core/lib/strings/str_util.h"
54 #include "tensorflow/core/lib/strings/strcat.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/protobuf.h"
58 #include "tensorflow/core/platform/tensor_coding.h"
59 #include "tensorflow/core/platform/types.h"
60
61 namespace tensorflow {
62
// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
// (The string is presumably the registry key used to locate the decoder at
// runtime; see variant_op_registry.h to confirm.)
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
71
GetAllocatedBytes(size_t * out_bytes) const72 bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
73 AllocationDescription allocation_description;
74 FillAllocationDescription(&allocation_description);
75 if (allocation_description.allocated_bytes() > 0) {
76 *out_bytes = allocation_description.allocated_bytes();
77 return true;
78 } else {
79 return false;
80 }
81 }
82
83 namespace {
84
85 // An un-templated base class for Buffer.
86 class BufferBase : public TensorBuffer {
87 public:
BufferBase(Allocator * alloc,void * data_ptr)88 explicit BufferBase(Allocator* alloc, void* data_ptr)
89 : TensorBuffer(data_ptr), alloc_(alloc) {}
90
root_buffer()91 TensorBuffer* root_buffer() override { return this; }
92
GetAllocatedBytes(size_t * out_bytes) const93 bool GetAllocatedBytes(size_t* out_bytes) const override {
94 if (alloc_->TracksAllocationSizes()) {
95 *out_bytes = alloc_->AllocatedSize(data());
96 return *out_bytes > 0;
97 } else {
98 return false;
99 }
100 }
101
FillAllocationDescription(AllocationDescription * proto) const102 void FillAllocationDescription(AllocationDescription* proto) const override {
103 void* data_ptr = data();
104 int64_t rb = size();
105 proto->set_requested_bytes(rb);
106 proto->set_allocator_name(alloc_->Name());
107 proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
108 if (alloc_->TracksAllocationSizes()) {
109 int64_t ab = alloc_->AllocatedSize(data_ptr);
110 proto->set_allocated_bytes(ab);
111 int64_t id = alloc_->AllocationId(data_ptr);
112 if (id > 0) {
113 proto->set_allocation_id(id);
114 }
115 if (RefCountIsOne()) {
116 proto->set_has_single_reference(true);
117 }
118 }
119 }
120
121 protected:
RecordDeallocation()122 void RecordDeallocation() {
123 LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
124 alloc_->Name());
125 }
126
127 Allocator* const alloc_;
128 };
129
// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  // Allocates space for "n" elements of type T from "a". Definitions follow
  // the Helper/ProtoHelper machinery below.
  Buffer(Allocator* a, int64_t n);
  Buffer(Allocator* a, int64_t n, const AllocationAttributes& allocation_attr);

  // Size of the buffer in bytes.
  size_t size() const override { return sizeof(T) * elem_; }

 private:
  int64 elem_;  // Number of elements of type T held by this buffer.

  // Private: buffers are ref-counted and die via Unref().
  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};
146
// Logs an error when a decoded byte count does not match the size implied
// by the tensor's shape and element type.
void LogUnexpectedSize(int64_t actual, int64_t expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}
150
MemoryLoggingEnabled()151 bool MemoryLoggingEnabled() {
152 static bool memory_logging_enabled = LogMemory::IsEnabled();
153 return memory_logging_enabled;
154 }
155
// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple_type<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encodes the "n" elements of type T stored in "in" into "out" (usually
  // TensorProto::tensor_content). Note: AssignRefCounted may share "in"'s
  // storage with "out" (holding a ref on "in") rather than copying bytes,
  // depending on the Destination type.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copies the bytes from "in" into a freshly
  // allocated buffer of "n" elements. Returns nullptr (after logging) if
  // "in" has the wrong size, or nullptr if allocation fails.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      // Allocation failed; release the ref taken at construction.
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage: simple types occupy exactly sizeof(T) * n bytes.
  static int64 TotalBytes(TensorBuffer* in, int64_t n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};
195
196 // Helper specialization for string (the only non-simple type we
197 // support).
198 template <>
199 struct Helper<tstring> {
200 // Proto message uses RepeatedFieldType to hold repeated T.
201 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
202
203 // Encodes "n" elements of type string stored in "in" into Cord
204 // "out", which is usually the TensorProto::tensor_content.
205 template <typename Destination>
Encodetensorflow::__anon6d9e61650111::Helper206 static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
207 port::EncodeStringList(in->base<const tstring>(), n, out);
208 }
209
210 // Decodes "n" elements of type string from "in" and constructs a
211 // buffer out of it. Returns nullptr if the decoding fails. "in" is
212 // usually the TensorProto::tensor_content.
213 template <typename Source>
Decodetensorflow::__anon6d9e61650111::Helper214 static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
215 Buffer<tstring>* buf = new Buffer<tstring>(a, n);
216 tstring* strings = buf->template base<tstring>();
217 if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
218 buf->Unref();
219 return nullptr;
220 }
221 return buf;
222 }
223
224 // Returns the estimated memory usage of "n" elements of type T
225 // stored in buffer "in".
TotalBytestensorflow::__anon6d9e61650111::Helper226 static int64 TotalBytes(TensorBuffer* in, int n) {
227 int64_t tot = in->size();
228 DCHECK_EQ(tot, sizeof(tstring) * n);
229 const tstring* p = in->base<const tstring>();
230 for (int i = 0; i < n; ++i, ++p) tot += p->size();
231 return tot;
232 }
233 };
234
235 template <>
236 struct Helper<ResourceHandle> {
237 // Proto message uses RepeatedFieldType to hold repeated T.
238 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
239
240 // Encodes "n" elements of type ResourceHandle stored in "in" into destination
241 // "out", which is usually the TensorProto::tensor_content.
242 template <typename Destination>
Encodetensorflow::__anon6d9e61650111::Helper243 static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
244 EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
245 port::NewStringListEncoder(out));
246 }
247
248 // Decodes "n" elements of type string from "in" and constructs a
249 // buffer out of it. Returns nullptr if the decoding fails. "in" is
250 // usually the TensorProto::tensor_content.
251 template <typename Source>
Decodetensorflow::__anon6d9e61650111::Helper252 static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
253 auto* buf = new Buffer<ResourceHandle>(a, n);
254 ResourceHandle* ps = buf->template base<ResourceHandle>();
255 if (ps == nullptr ||
256 !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
257 buf->Unref();
258 return nullptr;
259 }
260 return buf;
261 }
262
263 // Returns the estimated memory usage of "n" elements of type T
264 // stored in buffer "in".
TotalBytestensorflow::__anon6d9e61650111::Helper265 static int64 TotalBytes(TensorBuffer* in, int n) {
266 return n * sizeof(ResourceHandle);
267 }
268 };
269
270 template <>
271 struct Helper<Variant> {
272 // Encodes "n" elements of type Variant stored in "in" into destination
273 // "out", which is usually the TensorProto::tensor_content.
274 template <typename Destination>
Encodetensorflow::__anon6d9e61650111::Helper275 static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
276 EncodeVariantList(in->base<const Variant>(), n,
277 port::NewStringListEncoder(out));
278 }
279
280 // Decodes "n" elements of type Variant from "in" and constructs a
281 // buffer out of it. Returns nullptr if the decoding fails. "in" is
282 // usually the TensorProto::tensor_content.
283 template <typename Source>
Decodetensorflow::__anon6d9e61650111::Helper284 static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
285 auto* buf = new Buffer<Variant>(a, n);
286 Variant* ps = buf->template base<Variant>();
287 if (ps == nullptr ||
288 !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
289 buf->Unref();
290 return nullptr;
291 }
292 return buf;
293 }
294
295 // Returns the estimated memory usage of "n" elements of type T
296 // stored in buffer "in".
TotalBytestensorflow::__anon6d9e61650111::Helper297 static int64 TotalBytes(TensorBuffer* in, int n) {
298 return n * sizeof(Variant);
299 }
300 };
301
// Traits mapping a C++ element type to the TensorProto repeated field that
// stores it. The unspecialized version is intentionally empty; every
// supported type gets a specialization below.
template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
// int32, string, etc) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
// Note: narrow and quantized integer types are widened to int32 on the
// wire (second argument), and tstring rides in the string_val field.
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(tstring, tstring, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
339
// int64 values are serialized in the int64_val field. The reinterpret_cast
// assumes protobuf_int64 is layout-compatible with int64 — TODO confirm on
// platforms where the protobuf scalar type differs.
template <>
struct ProtoHelper<int64> {
  static const int64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const int64*>(proto.int64_val().begin());
  }
  // Number of int64 values stored in the proto.
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  // Replaces the proto's int64_val contents with data[0, n).
  static void Fill(const int64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};
353
// uint64 values are serialized in the uint64_val field. The
// reinterpret_cast assumes protobuf_uint64 is layout-compatible with
// uint64 — TODO confirm on platforms where the protobuf scalar differs.
template <>
struct ProtoHelper<uint64> {
  static const uint64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
  }
  // Number of uint64 values stored in the proto.
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  // Replaces the proto's uint64_val contents with data[0, n).
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};
367
368 template <>
369 struct ProtoHelper<ResourceHandle> {
Begintensorflow::__anon6d9e61650111::ProtoHelper370 static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
371 const TensorProto& proto) {
372 return proto.resource_handle_val().begin();
373 }
NumElementstensorflow::__anon6d9e61650111::ProtoHelper374 static size_t NumElements(const TensorProto& proto) {
375 return proto.resource_handle_val().size();
376 }
Filltensorflow::__anon6d9e61650111::ProtoHelper377 static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
378 auto* handles = proto->mutable_resource_handle_val();
379 handles->Clear();
380 for (size_t i = 0; i < n; i++) {
381 data[i].AsProto(handles->Add());
382 }
383 }
384 };
385
386 template <>
387 struct ProtoHelper<Variant> {
388 static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
Begintensorflow::__anon6d9e61650111::ProtoHelper389 Begin(const TensorProto& proto) {
390 return proto.variant_val().begin();
391 }
NumElementstensorflow::__anon6d9e61650111::ProtoHelper392 static size_t NumElements(const TensorProto& proto) {
393 return proto.variant_val().size();
394 }
Filltensorflow::__anon6d9e61650111::ProtoHelper395 static void Fill(const Variant* data, size_t n, TensorProto* proto) {
396 auto* variant_values = proto->mutable_variant_val();
397 variant_values->Clear();
398 for (size_t i = 0; i < n; ++i) {
399 VariantTensorData tmp;
400 data[i].Encode(&tmp);
401 tmp.ToProto(variant_values->Add());
402 }
403 }
404 };
405
406 template <>
407 struct ProtoHelper<complex64> {
408 typedef Helper<float>::RepeatedFieldType FieldType;
Begintensorflow::__anon6d9e61650111::ProtoHelper409 static const complex64* Begin(const TensorProto& proto) {
410 return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
411 }
NumElementstensorflow::__anon6d9e61650111::ProtoHelper412 static size_t NumElements(const TensorProto& proto) {
413 return proto.scomplex_val().size() / 2;
414 }
Filltensorflow::__anon6d9e61650111::ProtoHelper415 static void Fill(const complex64* data, size_t n, TensorProto* proto) {
416 const float* p = reinterpret_cast<const float*>(data);
417 FieldType copy(p, p + n * 2);
418 proto->mutable_scomplex_val()->Swap(©);
419 }
420 };
421
422 template <>
423 struct ProtoHelper<complex128> {
424 typedef Helper<double>::RepeatedFieldType FieldType;
Begintensorflow::__anon6d9e61650111::ProtoHelper425 static const complex128* Begin(const TensorProto& proto) {
426 return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
427 }
NumElementstensorflow::__anon6d9e61650111::ProtoHelper428 static size_t NumElements(const TensorProto& proto) {
429 return proto.dcomplex_val().size() / 2;
430 }
Filltensorflow::__anon6d9e61650111::ProtoHelper431 static void Fill(const complex128* data, size_t n, TensorProto* proto) {
432 const double* p = reinterpret_cast<const double*>(data);
433 FieldType copy(p, p + n * 2);
434 proto->mutable_dcomplex_val()->Swap(©);
435 }
436 };
437
438 template <>
439 struct ProtoHelper<qint32> {
440 typedef Helper<int32>::RepeatedFieldType FieldType;
Begintensorflow::__anon6d9e61650111::ProtoHelper441 static const qint32* Begin(const TensorProto& proto) {
442 return reinterpret_cast<const qint32*>(proto.int_val().data());
443 }
NumElementstensorflow::__anon6d9e61650111::ProtoHelper444 static size_t NumElements(const TensorProto& proto) {
445 return proto.int_val().size();
446 }
Filltensorflow::__anon6d9e61650111::ProtoHelper447 static void Fill(const qint32* data, size_t n, TensorProto* proto) {
448 const int32* p = reinterpret_cast<const int32*>(data);
449 FieldType copy(p, p + n);
450 proto->mutable_int_val()->Swap(©);
451 }
452 };
453
454 template <>
455 struct ProtoHelper<bfloat16> {
Filltensorflow::__anon6d9e61650111::ProtoHelper456 static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
457 proto->mutable_half_val()->Reserve(n);
458 for (size_t i = 0; i < n; ++i) {
459 proto->mutable_half_val()->AddAlreadyReserved(
460 Eigen::numext::bit_cast<uint16>(data[i]));
461 }
462 }
463 };
464
465 template <>
466 struct ProtoHelper<Eigen::half> {
Filltensorflow::__anon6d9e61650111::ProtoHelper467 static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
468 proto->mutable_half_val()->Reserve(n);
469 for (size_t i = 0; i < n; ++i) {
470 proto->mutable_half_val()->AddAlreadyReserved(
471 Eigen::numext::bit_cast<uint16>(data[i]));
472 }
473 }
474 };
475
// Allocates storage for n elements of T with default allocation attributes;
// data() is nullptr if the allocation fails.
template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
      elem_(n) {}
480
// Same as above but with caller-supplied allocation attributes (e.g.
// retry-on-failure behavior); data() is nullptr if the allocation fails.
template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
      elem_(n) {}
486
487 template <typename T>
~Buffer()488 Buffer<T>::~Buffer() {
489 if (data()) {
490 if (MemoryLoggingEnabled()) {
491 RecordDeallocation();
492 }
493 TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
494 }
495 }
496
// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in". If "in" has less values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine is using the typed fields (float_val, etc.) in the
// tensor proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto is
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    // Allocation failed; drop the construction-time reference.
    buf->Unref();
    return nullptr;
  }

  const int64_t in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      if (std::is_trivially_copyable<T>::value) {
        // Pad the tail with the last decoded element; taking the fill value
        // by copy is free for trivial types.
        const T last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      } else {
        // For non-trivial types, avoid the extra copy. "last" refers to the
        // element just before the fill range [data + in_n, data + n), which
        // fill_n never writes, so there is no aliasing hazard.
        const T& last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      }
    }
  }

  return buf;
}
538
// Variant specialization: each element must additionally be decoded via the
// unary-variant registry after being copied out of the proto.
template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    // Allocation failed; drop the construction-time reference.
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    // If tensor shape says we have n < in_n elements in the output tensor
    // then make sure to only decode the first n out of the in_n elements in the
    // in tensors. In all other cases, we decode all in_n elements of in and set
    // the remaining elements up to n to be the default Variant() value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        // Decoding failed (e.g. no decoder registered); the buffer owns the
        // partially-decoded elements and destroys them on Unref.
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\".  Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}
575
576 // fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
577 // identical to uint16 but with data stored in half_val instead of int_val (ie.,
578 // we don't use ProtoHelper<uint16>).
579 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64_t n)580 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
581 int64_t n) {
582 CHECK_GT(n, 0);
583 Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
584 uint16* data = buf->template base<uint16>();
585 if (data == nullptr) {
586 buf->Unref();
587 return nullptr;
588 }
589 const int64_t in_n = in.half_val().size();
590 auto begin = in.half_val().begin();
591 if (n <= in_n) {
592 std::copy_n(begin, n, data);
593 } else if (in_n > 0) {
594 std::copy_n(begin, in_n, data);
595 const uint16 last = *(data + in_n - 1);
596 std::fill_n(data + in_n, n - in_n, last);
597 } else {
598 std::fill_n(data, n, 0);
599 }
600 return buf;
601 }
602
603 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64_t n)604 TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
605 int64_t n) {
606 CHECK_GT(n, 0);
607 Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
608 uint16* data = buf->template base<uint16>();
609 if (data == nullptr) {
610 buf->Unref();
611 return nullptr;
612 }
613 const int64_t in_n = in.half_val().size();
614 auto begin = in.half_val().begin();
615 if (n <= in_n) {
616 std::copy_n(begin, n, data);
617 } else if (in_n > 0) {
618 std::copy_n(begin, in_n, data);
619 const uint16 last = *(data + in_n - 1);
620 std::fill_n(data + in_n, n - in_n, last);
621 } else {
622 std::fill_n(data, n, 0);
623 }
624 return buf;
625 }
626
// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64_t n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32. If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}
638
// Increments buf's refcount; no-op when buf is null.
void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

// Decrements buf's refcount (possibly destroying it); no-op when null.
void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}
646
647 } // end namespace
648
// Default tensor: DT_FLOAT dtype with no backing buffer.
Tensor::Tensor() : Tensor(DT_FLOAT) {}

// Buffer-less tensor of the given dtype (the dtype is stored inside
// shape_).
Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}

// Wraps an existing buffer; takes an additional reference on "buf" (the
// caller keeps its own).
Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

// Wraps an existing buffer, adopting the reference held by "buf";
// release() transfers it without touching the refcount.
Tensor::Tensor(DataType type, TensorShape shape,
               core::RefCountPtr<TensorBuffer> buf)
    : shape_(std::move(shape)), buf_(buf.release()) {
  set_dtype(type);
}
664
IsInitialized() const665 bool Tensor::IsInitialized() const {
666 return (buf_ != nullptr && buf_->data() != nullptr) ||
667 shape_.num_elements() == 0;
668 }
669
// CHECK-fails unless this tensor's dtype matches "expected_dtype".
void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
}

// CHECK-fails unless the dtype matches and the data pointer satisfies the
// alignment requirement checked by IsAligned().
void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
  CHECK(IsAligned()) << "ptr = " << base<void>();
}
682
CheckIsAlignedAndSingleElement() const683 void Tensor::CheckIsAlignedAndSingleElement() const {
684 CHECK(IsAligned()) << "Aligned and single element";
685 CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
686 }
687
// Releases this tensor's reference on its buffer (if any).
Tensor::~Tensor() { UnrefIfNonNull(buf_); }
689
BitcastFrom(const Tensor & other,DataType dtype,const TensorShape & shape)690 Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
691 const TensorShape& shape) {
692 int in_size = DataTypeSize(other.dtype());
693 int out_size = DataTypeSize(dtype);
694 if (in_size == 0) {
695 return errors::InvalidArgument("other tensor has zero-sized data type");
696 }
697 if (out_size == 0) {
698 return errors::InvalidArgument("specified output type is zero-sized");
699 }
700 if (shape.num_elements() * out_size !=
701 other.shape().num_elements() * in_size) {
702 return errors::InvalidArgument(
703 "input and output shapes/data type sizes are not compatible");
704 }
705 shape_ = shape;
706 shape_.set_data_type(dtype);
707 if (buf_ != other.buf_) {
708 UnrefIfNonNull(buf_);
709 buf_ = other.buf_;
710 RefIfNonNull(buf_);
711 }
712 return Status::OK();
713 }
714
715 // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
716 // For the latter case, we have to make sure that the refcount is
717 // one both for the SubBuffer _and_ the underlying TensorBuffer.
RefCountIsOne() const718 bool Tensor::RefCountIsOne() const {
719 return buf_ != nullptr && buf_->RefCountIsOne() &&
720 buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
721 }
722
// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
// One switch case: binds the typedef T to TYPE and runs STMTS.
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
// Dispatches STMTS over every supported dtype. INVALID runs for DT_INVALID
// and DEFAULT for any enum value without a case (unsupported dtypes).
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(tstring, SINGLE_ARG(STMTS))                           \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

// Like CASES_WITH_DEFAULT, but both the invalid and unknown-dtype paths
// are fatal.
#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS,                               \
                     LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
                     , LOG(FATAL) << "Type not set";)
769
// Allocates a tensor of the given type and shape from allocator "a" using
// default allocation attributes.
Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  // Skip allocation for zero-element tensors unless the allocator hands out
  // opaque handles, in which case a buffer is always required.
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}
782
// Allocates a tensor of the given type and shape from allocator "a" with
// caller-supplied allocation attributes.
Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  // Skip allocation for zero-element tensors unless the allocator hands out
  // opaque handles, in which case a buffer is always required.
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  // Skip the log entry when the caller promised to log this allocation
  // itself (allocation_will_be_logged).
  if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
      buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}
797
// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
// acquiring a lock, which guards initialization of the per-NUMA zone
// allocators, and becomes highly contended.
//
// Note also that it would be better if all Tensor allocations required the user
// to specify an allocator, for purposes of accounting, etc. However, the
// default allocator is widely used throughout the codebase and in client code.
static Allocator* get_default_cpu_allocator() {
  // Cached so the contended cpu_allocator() lookup happens only once.
  static Allocator* default_cpu_allocator =
      cpu_allocator(port::kNUMANoAffinity);
  return default_cpu_allocator;
}

// Convenience constructor: allocate from the default CPU allocator.
Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(get_default_cpu_allocator(), type, shape) {}
814
bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
    size_t* out_bytes) const {
  // `this->FillAllocationDescription()` never sets allocated bytes
  // information, so we can short-circuit the construction of an
  // `AllocationDescription`.
  return false;
}

// Describes the host-scalar buffer. There is no Allocator behind it, so
// only the requested size, a synthetic allocator name, and the data
// pointer are reported.
void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
    AllocationDescription* proto) const {
  proto->set_requested_bytes(size());
  proto->set_allocator_name("HostScalarTensorBuffer");
  proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
}
828
829 template <typename T>
830 class SubBuffer : public TensorBuffer {
831 public:
832 // This buffer is an alias to buf[delta, delta + n).
SubBuffer(TensorBuffer * buf,int64_t delta,int64_t n)833 SubBuffer(TensorBuffer* buf, int64_t delta, int64_t n)
834 : TensorBuffer(buf->base<T>() + delta),
835 root_(buf->root_buffer()),
836 elem_(n) {
837 // Sanity check. The caller should ensure the sub buffer is valid.
838 CHECK_LE(root_->base<T>(), this->base<T>());
839 T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
840 CHECK_LE(this->base<T>(), root_limit);
841 CHECK_LE(this->base<T>() + n, root_limit);
842 // Hold a ref of the underlying root buffer.
843 // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
844 root_->Ref();
845 }
846
size() const847 size_t size() const override { return sizeof(T) * elem_; }
root_buffer()848 TensorBuffer* root_buffer() override { return root_; }
GetAllocatedBytes(size_t * out_bytes) const849 bool GetAllocatedBytes(size_t* out_bytes) const override {
850 return root_->GetAllocatedBytes(out_bytes);
851 }
FillAllocationDescription(AllocationDescription * proto) const852 void FillAllocationDescription(AllocationDescription* proto) const override {
853 root_->FillAllocationDescription(proto);
854 }
855
856 private:
857 TensorBuffer* root_;
858 int64 elem_;
859
~SubBuffer()860 ~SubBuffer() override { root_->Unref(); }
861
862 TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
863 };
864
// Returns a tensor that aliases rows [start, limit) of this tensor along
// dimension 0. No data is copied: the result shares (and refs) this tensor's
// underlying buffer via SubBuffer.
Tensor Tensor::Slice(int64_t start, int64_t limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64_t dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  // Slicing the full range is the identity; share buf_ directly.
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.shape_ = shape_;
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    // Number of elements in one "row" along dimension 0.
    const int64_t elems_per_dim0 = NumElements() / dim0_size;
    const int64_t delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64_t num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      // CASES dispatches on dtype, binding T to the element's C++ type.
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
891
SubSlice(int64_t index) const892 Tensor Tensor::SubSlice(int64_t index) const {
893 CHECK_GE(dims(), 1); // Crash ok.
894 CHECK_LE(0, index); // Crash ok.
895 int64_t dim0_size = shape_.dim_size(0);
896 CHECK_LE(index, dim0_size); // Crash ok.
897 Tensor ret;
898 ret.shape_ = shape_;
899 ret.shape_.RemoveDim(0);
900 ret.set_dtype(dtype());
901 ret.buf_ = nullptr;
902 if (dim0_size > 0) {
903 const int64_t elems_per_dim0 = NumElements() / dim0_size;
904 const int64_t delta = index * elems_per_dim0;
905 const int64_t num_elems = elems_per_dim0;
906 if (buf_) {
907 DataType dt = dtype();
908 CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
909 }
910 }
911 return ret;
912 }
913
// Parses `proto` using the default CPU allocator; see the allocator-taking
// overload below for failure modes.
bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(get_default_cpu_allocator(), proto);
}
917
// Deserializes `proto` into this tensor, allocating the new buffer from `a`.
// Returns false — leaving *this unmodified — on an invalid shape, invalid
// dtype, or content that fails to decode; returns true on success.
bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64_t N = shape.num_elements();
  // NOTE(review): `proto.dtype()` is always nonzero here because DT_INVALID
  // (== 0) was rejected above, so the second conjunct appears redundant.
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      // Packed byte-string encoding: decode tensor_content directly.
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      // Typed repeated-field encoding: read per-element proto fields.
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  }
  // Commit: only now do we mutate *this.
  shape_ = shape;
  set_dtype(proto.dtype());
  // Release any previously held buffer before adopting the new one.
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}
949
// Serializes this tensor into `proto` using the typed repeated fields (one
// entry per element), as opposed to AsProtoTensorContent's packed bytes.
void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  shape_.AsProto(proto->mutable_tensor_shape());
  proto->set_dtype(dtype());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}
958
// Serializes this tensor into `proto`, packing all values into the compact
// tensor_content byte string (cf. AsProtoField's per-element encoding).
void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}
968
// Returns the number of bytes required for this tensor's data, computed
// per-dtype by Helper<T>::TotalBytes. Empty tensors report 0; a non-empty
// tensor without a buffer is a fatal error.
size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}
975
AllocatedBytes() const976 size_t Tensor::AllocatedBytes() const {
977 if (buf_) {
978 size_t ret;
979 if (buf_->GetAllocatedBytes(&ret)) {
980 return ret;
981 }
982 }
983 return TotalBytes();
984 }
985
// True iff the element type is a simple type per is_simple_type<T>, i.e.
// the buffer's raw bytes can be transferred directly.
bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}
990
// The CASES/CASE dispatch macros are only needed by the methods above;
// undefine them so they don't leak into the rest of this file.
#undef CASES
#undef CASE
993
994 namespace {
995
996 // StrCat and StrAppend don't support Eigen::half directly at the moment, and
997 // we would like to keep them compatible with their absl counterparts, for ease
998 // of migration. We could rely on errors::internal::PrepareForStrCat() but the
999 // logic is so simple we can just replicate it here, where it is close to its
1000 // usage and easy to change later. And there's the extra benefit of not
1001 // accessing an 'internal' namespace.
// Identity overload: types AlphaNum already accepts pass straight through.
// NOTE: returns a reference to the argument, so the result must not outlive
// the caller's temporary (fine for immediate StrAppend use, as done here).
inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                bool print_v2) {
  return a;
}
PrintOneElement(const tstring & a,bool print_v2)1006 inline string PrintOneElement(const tstring& a, bool print_v2) {
1007 if (print_v2) {
1008 return "\"" + absl::Utf8SafeCEscape(a) + "\"";
1009 } else {
1010 return absl::Utf8SafeCEscape(a);
1011 }
1012 }
// Widen half to float so StrCat/StrAppend can format it (see comment above).
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
  return static_cast<float>(h);
}
1016
// Widen bfloat16 to float so StrCat/StrAppend can format it.
inline float PrintOneElement(bfloat16 f, bool print_v2) {
  return static_cast<float>(f);
}
1020
// Print from left dim to right dim recursively. Writes at most `limit`
// elements overall — tracked via *data_index across recursive calls —
// separating siblings with spaces and wrapping each non-innermost dimension
// in '[' ... ']'.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64_t limit, int shape_size, const T* data,
                 int64* data_index, string* result) {
  if (*data_index >= limit) return;
  int64_t element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64_t i = 0; i < element_count; i++) {
      if (*data_index >= limit) {
        // If not enough elements have been printed, append "...".
        if (dim_index != 0) {
          strings::StrAppend(result, "...");
        }
        return;
      }
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
    }
    return;
  }
  // Loop every element of one dim.
  for (int64_t i = 0; i < element_count; i++) {
    // True iff we emitted the opening "[" for this child (i.e. we were still
    // under the element limit when the child started).
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // As for each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    // Close the bracket if more output follows or if we opened one above.
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}
1059
1060 // Appends the spacing between elements for a given dim onto a result string
PrintDimSpacing(int dim_index,int num_dims,string * result)1061 void PrintDimSpacing(int dim_index, int num_dims, string* result) {
1062 if (dim_index == num_dims - 1) {
1063 strings::StrAppend(result, " ");
1064 return;
1065 }
1066 for (int j = 0; j < num_dims - dim_index - 1; j++) {
1067 strings::StrAppend(result, "\n");
1068 }
1069 for (int j = 0; j <= dim_index; j++) {
1070 strings::StrAppend(result, " ");
1071 }
1072 }
1073
// Print from left dim to right dim recursively. V2 printer: for each
// dimension it emits the first and last `num_elts_at_ends` entries with
// "..." between them when middle entries are elided, producing nested
// bracketed output.
template <typename T>
void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                   int64_t num_elts_at_ends, int num_dims, const T* data,
                   int64_t data_index, string* result) {
  // We have recursed beyond all the dimensions into a single element
  // of the tensor.
  if (dim_index == num_dims) {
    strings::StrAppend(result, PrintOneElement(data[data_index], true));
    return;
  }

  strings::StrAppend(result, "[");
  int64_t element_count = shape[dim_index];
  // First index of the trailing group; when nothing is elided this clamps so
  // the leading and trailing groups exactly cover the dimension.
  int64_t start_of_end =
      std::max(num_elts_at_ends, element_count - num_elts_at_ends);

  // Loop every element of one dim.
  // Stride (in flattened elements) between successive entries of this dim.
  int64_t elements_per_iter = 1;
  for (int i = dim_index + 1; i < num_dims; i++) {
    elements_per_iter *= shape[i];
  }
  // Leading group of entries.
  for (int64_t i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
    if (i > 0) {
      PrintDimSpacing(dim_index, num_dims, result);
    }

    // As for each element, print the sub-dim.
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }
  // Elision marker, emitted only when some middle entries were skipped.
  if (element_count > 2 * num_elts_at_ends) {
    PrintDimSpacing(dim_index, num_dims, result);
    strings::StrAppend(result, "...");
  }
  // Trailing group of entries.
  for (int64_t i = start_of_end; i < element_count; i++) {
    // As for each element, print the sub-dim.
    PrintDimSpacing(dim_index, num_dims, result);
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }

  strings::StrAppend(result, "]");
}
1118
// Formats up to `limit` of the `num_elts` elements at `data` (reinterpreted
// as T*) according to `tensor_shape`. In the legacy (non-v2) path, appends
// "..." when elements were omitted; the v2 printer elides internally.
template <typename T>
string SummarizeArray(int64_t limit, int64_t num_elts,
                      const TensorShape& tensor_shape, const char* data,
                      const bool print_v2) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
  // Rank 0 (scalar): no dims to recurse over, print elements flat.
  if (shape.empty()) {
    for (int64_t i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  if (print_v2) {
    const int num_dims = tensor_shape.dims();
    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
  } else {
    int64_t data_index = 0;
    const int shape_size = tensor_shape.dims();
    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

    if (num_elts > limit) strings::StrAppend(&ret, "...");
  }

  return ret;
}
1148 } // namespace
1149
SummarizeValue(int64_t max_entries,bool print_v2) const1150 string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const {
1151 const int64_t num_elts = NumElements();
1152 if (max_entries < 0) {
1153 max_entries = num_elts;
1154 }
1155 size_t limit = std::min(max_entries, num_elts);
1156 if ((limit > 0) && (buf_ == nullptr)) {
1157 return strings::StrCat("uninitialized Tensor of ", num_elts,
1158 " elements of type ", dtype());
1159 }
1160 const char* data = limit > 0 ? tensor_data().data() : nullptr;
1161 switch (dtype()) {
1162 case DT_BFLOAT16:
1163 return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
1164 break;
1165 case DT_HALF:
1166 return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
1167 print_v2);
1168 break;
1169 case DT_FLOAT:
1170 return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
1171 break;
1172 case DT_DOUBLE:
1173 return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
1174 break;
1175 case DT_UINT32:
1176 return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
1177 break;
1178 case DT_INT32:
1179 return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
1180 break;
1181 case DT_UINT8:
1182 case DT_QUINT8:
1183 return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
1184 break;
1185 case DT_UINT16:
1186 case DT_QUINT16:
1187 return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
1188 break;
1189 case DT_INT16:
1190 case DT_QINT16:
1191 return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
1192 break;
1193 case DT_INT8:
1194 case DT_QINT8:
1195 return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
1196 break;
1197 case DT_UINT64:
1198 return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
1199 break;
1200 case DT_INT64:
1201 return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
1202 break;
1203 case DT_BOOL:
1204 // TODO(tucker): Is it better to emit "True False..."? This
1205 // will emit "1 0..." which is more compact.
1206 return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
1207 break;
1208 case DT_STRING:
1209 return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
1210 break;
1211 default: {
1212 // All irregular cases
1213 string ret;
1214 if (print_v2 && (dims() > 0)) {
1215 strings::StrAppend(&ret, "[");
1216 }
1217 // TODO(irving): Don't call flat every time around this
1218 // loop.
1219 for (size_t i = 0; i < limit; ++i) {
1220 if (i > 0) strings::StrAppend(&ret, " ");
1221 switch (dtype()) {
1222 case DT_VARIANT: {
1223 const Variant& v = flat<Variant>()(i);
1224 strings::StrAppend(&ret, "<", v.SummarizeValue(), ">");
1225 } break;
1226 case DT_RESOURCE: {
1227 const ResourceHandle& r = flat<ResourceHandle>()(i);
1228 strings::StrAppend(&ret, "<", r.SummarizeValue(), ">");
1229 } break;
1230 default:
1231 // TODO(zhifengc, josh11b): Pretty-print other types (bool,
1232 // complex64, quantized).
1233 strings::StrAppend(&ret, "?");
1234 }
1235 }
1236 if (max_entries < num_elts) strings::StrAppend(&ret, "...");
1237 if (print_v2 && (dims() > 0)) {
1238 strings::StrAppend(&ret, "]");
1239 }
1240 return ret;
1241 }
1242 }
1243 }
1244
tensor_data() const1245 StringPiece Tensor::tensor_data() const {
1246 if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors
1247 return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
1248 }
1249
data() const1250 void* Tensor::data() const {
1251 if (buf_ == nullptr) return nullptr; // Don't die for empty tensors
1252 return static_cast<void*>(buf_->data());
1253 }
1254
SharesBufferWith(const Tensor & b) const1255 bool Tensor::SharesBufferWith(const Tensor& b) const {
1256 return buf_ != nullptr && b.buf_ != nullptr &&
1257 buf_->root_buffer() == b.buf_->root_buffer();
1258 }
1259
// Single-line description including type, shape, and up to `num_values`
// summarized elements.
string Tensor::DebugString(int num_values) const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(num_values), ">");
}
1265
// Like DebugString() but without element values, so it never touches the
// underlying data buffer (hence "device safe").
string Tensor::DeviceSafeDebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(), ">");
}
1270
FillDescription(TensorDescription * description) const1271 void Tensor::FillDescription(TensorDescription* description) const {
1272 description->set_dtype(dtype());
1273 shape().AsProto(description->mutable_shape());
1274 if (buf_ != nullptr && buf_->data() != nullptr) {
1275 buf_->FillAllocationDescription(
1276 description->mutable_allocation_description());
1277 }
1278 }
1279
ComputeFlatInnerDims(gtl::ArraySlice<int64> orig,int64_t num_out_dims)1280 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
1281 gtl::ArraySlice<int64> orig, int64_t num_out_dims) {
1282 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1283 int64_t offset = orig.size() - num_out_dims;
1284 for (int64_t out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
1285 const int64_t in_dim = out_dim + offset;
1286 out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
1287 }
1288 for (int64_t in_dim = 0; in_dim < offset; ++in_dim) {
1289 out_dims[0] *= orig[in_dim];
1290 }
1291 return out_dims;
1292 }
1293
ComputeFlatOuterDims(gtl::ArraySlice<int64> orig,int64_t num_out_dims)1294 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
1295 gtl::ArraySlice<int64> orig, int64_t num_out_dims) {
1296 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1297 for (int64_t out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
1298 out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
1299 }
1300 for (int64_t in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
1301 out_dims[num_out_dims - 1] *= orig[in_dim];
1302 }
1303 return out_dims;
1304 }
1305
1306 } // namespace tensorflow
1307