1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implementation notes:
17 //
18 // Tensor.cc uses a few templated classes and structs to facilitate
19 // implementation of the Tensor class.
20 //
21 // * Buffer<T>: provides the implementation for a typed array T[n].
22 // The array is allocated by the given allocator. It runs T's
23 // default constructors and destructors when T is not a simple type
24 // (e.g., string), and skips them otherwise.
25 //
26 // * Helper<T>: provides various routines given type T. The routines
27 // include running the constructor and destructor of T[], encoding
28 // and decoding T[] into/from a Cord, etc.
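//
// As a rough sketch of how these pieces fit together (illustrative only):
// decoding a TensorProto's tensor_content bottoms out in Helper<T>::Decode(),
// which allocates a Buffer<T> and copies/decodes the bytes into it, while
// Helper<T>::Encode() serializes an existing buffer back out.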
29
30 #include "tensorflow/core/framework/tensor.h"
31
32 #include "absl/strings/escaping.h"
33 #include "tensorflow/core/framework/allocation_description.pb.h"
34 #include "tensorflow/core/framework/log_memory.h"
35 #include "tensorflow/core/framework/resource_handle.pb.h"
36 #include "tensorflow/core/framework/tensor.pb.h"
37 #include "tensorflow/core/framework/tensor_description.pb.h"
38 #include "tensorflow/core/framework/type_traits.h"
39 #include "tensorflow/core/framework/typed_allocator.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/framework/variant.h"
42 #include "tensorflow/core/framework/variant_encode_decode.h"
43 #include "tensorflow/core/framework/variant_op_registry.h"
44 #include "tensorflow/core/framework/variant_tensor_data.h"
45 #include "tensorflow/core/lib/core/coding.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/gtl/inlined_vector.h"
49 #include "tensorflow/core/lib/strings/str_util.h"
50 #include "tensorflow/core/lib/strings/strcat.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/tensor_coding.h"
55 #include "tensorflow/core/platform/types.h"
56
57 namespace tensorflow {
58
59 // Allow Tensors to be stored inside Variants with automatic
60 // encoding/decoding when those Variants are themselves being decoded
61 // in a Tensor's FromProto.
62 //
63 // NOTE(mrry): The corresponding "copy function" registrations can be found in
64 // ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
65 // code).
66 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
67
68 bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
69 AllocationDescription allocation_description;
70 FillAllocationDescription(&allocation_description);
71 if (allocation_description.allocated_bytes() > 0) {
72 *out_bytes = allocation_description.allocated_bytes();
73 return true;
74 } else {
75 return false;
76 }
77 }
78
79 namespace {
80
81 // An un-templated base class for Buffer.
82 class BufferBase : public TensorBuffer {
83 public:
84 explicit BufferBase(Allocator* alloc, void* data_ptr)
85 : TensorBuffer(data_ptr), alloc_(alloc) {}
86
87 TensorBuffer* root_buffer() override { return this; }
88
89 bool GetAllocatedBytes(size_t* out_bytes) const override {
90 if (alloc_->TracksAllocationSizes()) {
91 *out_bytes = alloc_->AllocatedSize(data());
92 return *out_bytes > 0;
93 } else {
94 return false;
95 }
96 }
97
98 void FillAllocationDescription(AllocationDescription* proto) const override {
99 void* data_ptr = data();
100 int64 rb = size();
101 proto->set_requested_bytes(rb);
102 proto->set_allocator_name(alloc_->Name());
103 proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
104 if (alloc_->TracksAllocationSizes()) {
105 int64 ab = alloc_->AllocatedSize(data_ptr);
106 proto->set_allocated_bytes(ab);
107 int64 id = alloc_->AllocationId(data_ptr);
108 if (id > 0) {
109 proto->set_allocation_id(id);
110 }
111 if (RefCountIsOne()) {
112 proto->set_has_single_reference(true);
113 }
114 }
115 }
116
117 protected:
118 void RecordDeallocation() {
119 LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
120 alloc_->Name());
121 }
122
123 Allocator* const alloc_;
124 };
125
126 // Typed ref-counted buffer: T[n].
127 template <typename T>
128 class Buffer : public BufferBase {
129 public:
130 Buffer(Allocator* a, int64 n);
131 Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);
132
133 size_t size() const override { return sizeof(T) * elem_; }
134
135 private:
136 int64 elem_;
137
138 ~Buffer() override;
139
140 TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
141 };
142
143 void LogUnexpectedSize(int64 actual, int64 expected) {
144 LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
145 }
146
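// Latches LogMemory::IsEnabled() in a function-local static so the check
// stays cheap on allocation/deallocation hot paths.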
147 bool MemoryLoggingEnabled() {
148 static bool memory_logging_enabled = LogMemory::IsEnabled();
149 return memory_logging_enabled;
150 }
151
152 // A set of helper functions depending on T.
153 template <typename T>
154 struct Helper {
155 // By default, we assume T is a simple type (float, int32, etc.)
156 static_assert(is_simple_type<T>::value, "T is not a simple type.");
157 typedef protobuf::RepeatedField<T> RepeatedFieldType;
158
159 // Encoder of simple type T to a string. We do a copy.
160 template <typename Destination>
161 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
162 DCHECK_EQ(in->size(), sizeof(T) * n);
163 port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
164 out);
165 }
166
167 // Decoder of simple type T. Copy the bytes from "in" into the
168 // tensor buffer.
169 template <typename Source>
170 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
171 if (in.size() != sizeof(T) * n) {
172 LogUnexpectedSize(in.size(), sizeof(T) * n);
173 return nullptr;
174 }
175 Buffer<T>* buf = new Buffer<T>(a, n);
176 char* data = buf->template base<char>();
177 if (data == nullptr) {
178 buf->Unref();
179 return nullptr;
180 }
181 port::CopyToArray(in, data);
182 return buf;
183 }
184
185 // Memory usage.
186 static int64 TotalBytes(TensorBuffer* in, int64 n) {
187 DCHECK_EQ(in->size(), sizeof(T) * n);
188 return in->size();
189 }
190 };
191
192 // Helper specialization for string (the only non-simple type we
193 // support).
194 template <>
195 struct Helper<tstring> {
196 // Proto message uses RepeatedFieldType to hold repeated T.
197 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
198
199 // Encodes "n" elements of type string stored in "in" into Cord
200 // "out", which is usually the TensorProto::tensor_content.
201 template <typename Destination>
202 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
203 port::EncodeStringList(in->base<const tstring>(), n, out);
204 }
205
206 // Decodes "n" elements of type string from "in" and constructs a
207 // buffer out of it. Returns nullptr if the decoding fails. "in" is
208 // usually the TensorProto::tensor_content.
209 template <typename Source>
210 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
211 Buffer<tstring>* buf = new Buffer<tstring>(a, n);
212 tstring* strings = buf->template base<tstring>();
213 if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
214 buf->Unref();
215 return nullptr;
216 }
217 return buf;
218 }
219
220 // Returns the estimated memory usage of "n" elements of type T
221 // stored in buffer "in".
222 static int64 TotalBytes(TensorBuffer* in, int n) {
223 int64 tot = in->size();
224 DCHECK_EQ(tot, sizeof(tstring) * n);
225 const tstring* p = in->base<const tstring>();
226 for (int i = 0; i < n; ++i, ++p) tot += p->size();
227 return tot;
228 }
229 };
230
231 template <>
232 struct Helper<ResourceHandle> {
233 // Proto message uses RepeatedFieldType to hold repeated T.
234 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
235
236 // Encodes "n" elements of type ResourceHandle stored in "in" into destination
237 // "out", which is usually the TensorProto::tensor_content.
238 template <typename Destination>
239 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
240 EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
241 port::NewStringListEncoder(out));
242 }
243
244 // Decodes "n" elements of type ResourceHandle from "in" and constructs a
245 // buffer out of it. Returns nullptr if the decoding fails. "in" is
246 // usually the TensorProto::tensor_content.
247 template <typename Source>
248 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
249 auto* buf = new Buffer<ResourceHandle>(a, n);
250 ResourceHandle* ps = buf->template base<ResourceHandle>();
251 if (ps == nullptr ||
252 !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
253 buf->Unref();
254 return nullptr;
255 }
256 return buf;
257 }
258
259 // Returns the estimated memory usage of "n" elements of type T
260 // stored in buffer "in".
261 static int64 TotalBytes(TensorBuffer* in, int n) {
262 return n * sizeof(ResourceHandle);
263 }
264 };
265
266 template <>
267 struct Helper<Variant> {
268 // Encodes "n" elements of type Variant stored in "in" into destination
269 // "out", which is usually the TensorProto::tensor_content.
270 template <typename Destination>
271 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
272 EncodeVariantList(in->base<const Variant>(), n,
273 port::NewStringListEncoder(out));
274 }
275
276 // Decodes "n" elements of type Variant from "in" and constructs a
277 // buffer out of it. Returns nullptr if the decoding fails. "in" is
278 // usually the TensorProto::tensor_content.
279 template <typename Source>
280 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
281 auto* buf = new Buffer<Variant>(a, n);
282 Variant* ps = buf->template base<Variant>();
283 if (ps == nullptr ||
284 !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
285 buf->Unref();
286 return nullptr;
287 }
288 return buf;
289 }
290
291 // Returns the estimated memory usage of "n" elements of type T
292 // stored in buffer "in".
293 static int64 TotalBytes(TensorBuffer* in, int n) {
294 return n * sizeof(Variant);
295 }
296 };
297
298 template <typename T>
299 struct ProtoHelper {};
300
301 // For a C++ type "T" (float, double, int32, etc.), the repeated field
302 // "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
303 // int32, string, etc) in the TensorProto is used for serializing the
304 // tensor of type "T".
305 #define PROTO_TRAITS(T, F, N) \
306 template <> \
307 struct ProtoHelper<T> { \
308 typedef Helper<F>::RepeatedFieldType FieldType; \
309 static FieldType::const_iterator Begin(const TensorProto& proto) { \
310 return proto.N##_val().begin(); \
311 } \
312 static size_t NumElements(const TensorProto& proto) { \
313 return proto.N##_val().size(); \
314 } \
315 static void Fill(const T* data, size_t n, TensorProto* proto) { \
316 typename ProtoHelper<T>::FieldType copy(data, data + n); \
317 proto->mutable_##N##_val()->Swap(&copy); \
318 } \
319 };
320 PROTO_TRAITS(float, float, float);
321 PROTO_TRAITS(double, double, double);
322 PROTO_TRAITS(int32, int32, int);
323 PROTO_TRAITS(uint8, int32, int);
324 PROTO_TRAITS(uint16, int32, int);
325 PROTO_TRAITS(uint32, uint32, uint32);
326 PROTO_TRAITS(int16, int32, int);
327 PROTO_TRAITS(int8, int32, int);
328 PROTO_TRAITS(bool, bool, bool);
329 PROTO_TRAITS(tstring, tstring, string);
330 PROTO_TRAITS(qint8, int32, int);
331 PROTO_TRAITS(quint8, int32, int);
332 PROTO_TRAITS(qint16, int32, int);
333 PROTO_TRAITS(quint16, int32, int);
334 #undef PROTO_TRAITS
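// For example, PROTO_TRAITS(uint8, int32, int) above means a uint8 tensor is
// serialized into the TensorProto's int_val field, with each element widened
// to int32 on the way out.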
335
336 template <>
337 struct ProtoHelper<int64> {
338 static const int64* Begin(const TensorProto& proto) {
339 return reinterpret_cast<const int64*>(proto.int64_val().begin());
340 }
341 static size_t NumElements(const TensorProto& proto) {
342 return proto.int64_val().size();
343 }
344 static void Fill(const int64* data, size_t n, TensorProto* proto) {
345 protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
346 proto->mutable_int64_val()->Swap(&copy);
347 }
348 };
349
350 template <>
351 struct ProtoHelper<uint64> {
352 static const uint64* Begin(const TensorProto& proto) {
353 return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
354 }
355 static size_t NumElements(const TensorProto& proto) {
356 return proto.uint64_val().size();
357 }
358 static void Fill(const uint64* data, size_t n, TensorProto* proto) {
359 protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
360 proto->mutable_uint64_val()->Swap(&copy);
361 }
362 };
363
364 template <>
365 struct ProtoHelper<ResourceHandle> {
366 static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
367 const TensorProto& proto) {
368 return proto.resource_handle_val().begin();
369 }
370 static size_t NumElements(const TensorProto& proto) {
371 return proto.resource_handle_val().size();
372 }
373 static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
374 auto* handles = proto->mutable_resource_handle_val();
375 handles->Clear();
376 for (size_t i = 0; i < n; i++) {
377 data[i].AsProto(handles->Add());
378 }
379 }
380 };
381
382 template <>
383 struct ProtoHelper<Variant> {
384 static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
385 Begin(const TensorProto& proto) {
386 return proto.variant_val().begin();
387 }
388 static size_t NumElements(const TensorProto& proto) {
389 return proto.variant_val().size();
390 }
391 static void Fill(const Variant* data, size_t n, TensorProto* proto) {
392 auto* variant_values = proto->mutable_variant_val();
393 variant_values->Clear();
394 for (size_t i = 0; i < n; ++i) {
395 VariantTensorData tmp;
396 data[i].Encode(&tmp);
397 tmp.ToProto(variant_values->Add());
398 }
399 }
400 };
401
402 template <>
403 struct ProtoHelper<complex64> {
404 typedef Helper<float>::RepeatedFieldType FieldType;
405 static const complex64* Begin(const TensorProto& proto) {
406 return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
407 }
408 static size_t NumElements(const TensorProto& proto) {
409 return proto.scomplex_val().size() / 2;
410 }
411 static void Fill(const complex64* data, size_t n, TensorProto* proto) {
412 const float* p = reinterpret_cast<const float*>(data);
413 FieldType copy(p, p + n * 2);
414 proto->mutable_scomplex_val()->Swap(&copy);
415 }
416 };
417
418 template <>
419 struct ProtoHelper<complex128> {
420 typedef Helper<double>::RepeatedFieldType FieldType;
421 static const complex128* Begin(const TensorProto& proto) {
422 return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
423 }
424 static size_t NumElements(const TensorProto& proto) {
425 return proto.dcomplex_val().size() / 2;
426 }
427 static void Fill(const complex128* data, size_t n, TensorProto* proto) {
428 const double* p = reinterpret_cast<const double*>(data);
429 FieldType copy(p, p + n * 2);
430 proto->mutable_dcomplex_val()->Swap(&copy);
431 }
432 };
433
434 template <>
435 struct ProtoHelper<qint32> {
436 typedef Helper<int32>::RepeatedFieldType FieldType;
437 static const qint32* Begin(const TensorProto& proto) {
438 return reinterpret_cast<const qint32*>(proto.int_val().data());
439 }
440 static size_t NumElements(const TensorProto& proto) {
441 return proto.int_val().size();
442 }
443 static void Fill(const qint32* data, size_t n, TensorProto* proto) {
444 const int32* p = reinterpret_cast<const int32*>(data);
445 FieldType copy(p, p + n);
446 proto->mutable_int_val()->Swap(&copy);
447 }
448 };
449
450 template <>
451 struct ProtoHelper<bfloat16> {
452 static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
453 proto->mutable_half_val()->Reserve(n);
454 for (size_t i = 0; i < n; ++i) {
455 proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
456 }
457 }
458 };
459
460 template <>
461 struct ProtoHelper<Eigen::half> {
462 static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
463 proto->mutable_half_val()->Reserve(n);
464 for (size_t i = 0; i < n; ++i) {
465 proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
466 }
467 }
468 };
469
470 template <typename T>
471 Buffer<T>::Buffer(Allocator* a, int64 n)
472 : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
473 elem_(n) {}
474
475 template <typename T>
476 Buffer<T>::Buffer(Allocator* a, int64 n,
477 const AllocationAttributes& allocation_attr)
478 : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
479 elem_(n) {}
480
481 template <typename T>
482 Buffer<T>::~Buffer() {
483 if (data()) {
484 if (MemoryLoggingEnabled()) {
485 RecordDeallocation();
486 }
487 TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
488 }
489 }
490
491 // Allocates a T[n] buffer. Fills in the buffer with repeated values
492 // in "in". If "in" has fewer values than "n", fills the rest of T[n]
493 // with the last value. If "in" has no values, fills T[n] with the
494 // default value for T.
495 //
496 // This routine uses the typed fields (float_val, etc.) in the
497 // tensor proto as opposed to the untyped binary representation
498 // (tensor_content). This is used when we expect the TensorProto is
499 // used by a client program which may not know how to encode a tensor
500 // in the compact binary representation.
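//
// For example (assuming T == float): a proto whose float_val holds {1, 2},
// decoded into a tensor of 5 elements, yields the values {1, 2, 2, 2, 2}.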
501 template <typename T>
502 TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
503 CHECK_GT(n, 0);
504 Buffer<T>* buf = new Buffer<T>(a, n);
505 T* data = buf->template base<T>();
506 if (data == nullptr) {
507 buf->Unref();
508 return nullptr;
509 }
510
511 const int64 in_n = ProtoHelper<T>::NumElements(in);
512 if (in_n <= 0) {
513 std::fill_n(data, n, T());
514 } else {
515 auto begin = ProtoHelper<T>::Begin(in);
516 if (n <= in_n) {
517 std::copy_n(begin, n, data);
518 } else {
519 std::copy_n(begin, in_n, data);
520 if (std::is_trivially_copyable<T>::value) {
521 const T last = *(data + in_n - 1);
522 std::fill_n(data + in_n, n - in_n, last);
523 } else {
524 const T& last = *(data + in_n - 1);
525 std::fill_n(data + in_n, n - in_n, last);
526 }
527 }
528 }
529
530 return buf;
531 }
532
533 template <>
534 TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
535 int64 n) {
536 CHECK_GT(n, 0);
537 Buffer<Variant>* buf = new Buffer<Variant>(a, n);
538 Variant* data = buf->template base<Variant>();
539 if (data == nullptr) {
540 buf->Unref();
541 return nullptr;
542 }
543 const int64 in_n = ProtoHelper<Variant>::NumElements(in);
544 if (in_n <= 0) {
545 std::fill_n(data, n, Variant());
546 } else {
547 // If the tensor shape says we have n < in_n elements in the output tensor,
548 // only decode the first n of the in_n elements in the input. In all other
549 // cases, decode all in_n elements of the input and set the remaining
550 // elements up to n to the default Variant() value.
551 const int64 real_n = n < in_n ? n : in_n;
552 for (int64 i = 0; i < real_n; ++i) {
553 data[i] = in.variant_val(i);
554 if (!DecodeUnaryVariant(&data[i])) {
555 LOG(ERROR) << "Could not decode variant with type_name: \""
556 << data[i].TypeName()
557 << "\". Perhaps you forgot to register a "
558 "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
559 buf->Unref();
560 return nullptr;
561 }
562 }
563 for (int64 i = in_n; i < n; ++i) {
564 data[i] = Variant();
565 }
566 }
567 return buf;
568 }
569
570 // fp16 and bfloat16 are opaque to the protobuf, so we deserialize them
571 // identically to uint16 but with the data stored in half_val instead of
572 // int_val (i.e., we don't use ProtoHelper<uint16>).
573 template <>
574 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
575 int64 n) {
576 CHECK_GT(n, 0);
577 Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
578 uint16* data = buf->template base<uint16>();
579 if (data == nullptr) {
580 buf->Unref();
581 return nullptr;
582 }
583 const int64 in_n = in.half_val().size();
584 auto begin = in.half_val().begin();
585 if (n <= in_n) {
586 std::copy_n(begin, n, data);
587 } else if (in_n > 0) {
588 std::copy_n(begin, in_n, data);
589 const uint16 last = *(data + in_n - 1);
590 std::fill_n(data + in_n, n - in_n, last);
591 } else {
592 std::fill_n(data, n, 0);
593 }
594 return buf;
595 }
596
597 template <>
598 TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
599 int64 n) {
600 CHECK_GT(n, 0);
601 Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
602 uint16* data = buf->template base<uint16>();
603 if (data == nullptr) {
604 buf->Unref();
605 return nullptr;
606 }
607 const int64 in_n = in.half_val().size();
608 auto begin = in.half_val().begin();
609 if (n <= in_n) {
610 std::copy_n(begin, n, data);
611 } else if (in_n > 0) {
612 std::copy_n(begin, in_n, data);
613 const uint16 last = *(data + in_n - 1);
614 std::fill_n(data + in_n, n - in_n, last);
615 } else {
616 std::fill_n(data, n, 0);
617 }
618 return buf;
619 }
620
621 // Copies T[n] stored in the buffer "in" into the repeated field in
622 // "out" corresponding to type T.
623 template <typename T>
624 void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
625 const T* data = in.base<const T>();
626 // NOTE: T may not be the same as
627 // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
628 // ProtoHelper<T>::FieldType::value_type==int32. If performance is
629 // critical, we can specialize T=float and do memcpy directly.
630 ProtoHelper<T>::Fill(data, n, out);
631 }
632
633 void RefIfNonNull(core::RefCounted* buf) {
634 if (buf) buf->Ref();
635 }
636
637 void UnrefIfNonNull(core::RefCounted* buf) {
638 if (buf) buf->Unref();
639 }
640
641 } // end namespace
642
643 Tensor::Tensor() : Tensor(DT_FLOAT) {}
644
645 Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}
646
647 Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
648 : shape_(shape), buf_(buf) {
649 set_dtype(type);
650 RefIfNonNull(buf);
651 }
652
653 Tensor::Tensor(DataType type, const TensorShape& shape,
654 core::RefCountPtr<TensorBuffer> buf)
655 : shape_(shape), buf_(buf.release()) {
656 set_dtype(type);
657 }
658
659 bool Tensor::IsInitialized() const {
660 return (buf_ != nullptr && buf_->data() != nullptr) ||
661 shape_.num_elements() == 0;
662 }
663
664 void Tensor::CheckType(DataType expected_dtype) const {
665 CHECK_EQ(dtype(), expected_dtype)
666 << " " << DataTypeString(expected_dtype) << " expected, got "
667 << DataTypeString(dtype());
668 }
669
670 void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
671 CHECK_EQ(dtype(), expected_dtype)
672 << " " << DataTypeString(expected_dtype) << " expected, got "
673 << DataTypeString(dtype());
674 CHECK(IsAligned()) << "ptr = " << base<void>();
675 }
676
677 void Tensor::CheckIsAlignedAndSingleElement() const {
678 CHECK(IsAligned()) << "Aligned and single element";
679 CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
680 }
681
682 Tensor::~Tensor() { UnrefIfNonNull(buf_); }
683
684 Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
685 const TensorShape& shape) {
686 int in_size = DataTypeSize(other.dtype());
687 int out_size = DataTypeSize(dtype);
688 if (in_size == 0) {
689 return errors::InvalidArgument("other tensor has zero-sized data type");
690 }
691 if (out_size == 0) {
692 return errors::InvalidArgument("specified output type is zero-sized");
693 }
694 if (shape.num_elements() * out_size !=
695 other.shape().num_elements() * in_size) {
696 return errors::InvalidArgument(
697 "input and output shapes/data type sizes are not compatible");
698 }
699 shape_ = shape;
700 shape_.set_data_type(dtype);
701 if (buf_ != other.buf_) {
702 UnrefIfNonNull(buf_);
703 buf_ = other.buf_;
704 RefIfNonNull(buf_);
705 }
706 return Status::OK();
707 }
708
709 // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
710 // For the latter case, we have to make sure that the refcount is
711 // one both for the SubBuffer _and_ the underlying TensorBuffer.
712 bool Tensor::RefCountIsOne() const {
713 return buf_ != nullptr && buf_->RefCountIsOne() &&
714 buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
715 }
716
717 // The macro CASES() expands to a switch statement conditioned on
718 // TYPE_ENUM. Each case expands the STMTS after a typedef for T.
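// For example (illustrative), CASES(dtype(), total += sizeof(T)) expands to a
// switch over every supported DataType, binding T to the matching C++ type in
// each case before running the statement.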
719 #define SINGLE_ARG(...) __VA_ARGS__
720 #define CASE(TYPE, STMTS) \
721 case DataTypeToEnum<TYPE>::value: { \
722 typedef TYPE T; \
723 STMTS; \
724 break; \
725 }
726 #define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
727 switch (TYPE_ENUM) { \
728 CASE(float, SINGLE_ARG(STMTS)) \
729 CASE(double, SINGLE_ARG(STMTS)) \
730 CASE(int32, SINGLE_ARG(STMTS)) \
731 CASE(uint8, SINGLE_ARG(STMTS)) \
732 CASE(uint16, SINGLE_ARG(STMTS)) \
733 CASE(uint32, SINGLE_ARG(STMTS)) \
734 CASE(uint64, SINGLE_ARG(STMTS)) \
735 CASE(int16, SINGLE_ARG(STMTS)) \
736 CASE(int8, SINGLE_ARG(STMTS)) \
737 CASE(tstring, SINGLE_ARG(STMTS)) \
738 CASE(complex64, SINGLE_ARG(STMTS)) \
739 CASE(complex128, SINGLE_ARG(STMTS)) \
740 CASE(int64, SINGLE_ARG(STMTS)) \
741 CASE(bool, SINGLE_ARG(STMTS)) \
742 CASE(qint32, SINGLE_ARG(STMTS)) \
743 CASE(quint8, SINGLE_ARG(STMTS)) \
744 CASE(qint8, SINGLE_ARG(STMTS)) \
745 CASE(quint16, SINGLE_ARG(STMTS)) \
746 CASE(qint16, SINGLE_ARG(STMTS)) \
747 CASE(bfloat16, SINGLE_ARG(STMTS)) \
748 CASE(Eigen::half, SINGLE_ARG(STMTS)) \
749 CASE(ResourceHandle, SINGLE_ARG(STMTS)) \
750 CASE(Variant, SINGLE_ARG(STMTS)) \
751 case DT_INVALID: \
752 INVALID; \
753 break; \
754 default: \
755 DEFAULT; \
756 break; \
757 }
758
759 #define CASES(TYPE_ENUM, STMTS) \
760 CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, \
761 LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
762 , LOG(FATAL) << "Type not set";)
763
764 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
765 : shape_(shape), buf_(nullptr) {
766 set_dtype(type);
767 CHECK_NOTNULL(a);
768 if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
769 CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
770 }
771 if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
772 LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
773 *this);
774 }
775 }
776
777 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
778 const AllocationAttributes& allocation_attr)
779 : shape_(shape), buf_(nullptr) {
780 set_dtype(type);
781 CHECK_NOTNULL(a);
782 if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
783 CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
784 }
785 if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
786 buf_ != nullptr && buf_->data() != nullptr) {
787 LogMemory::RecordTensorAllocation("Unknown (with attributes)",
788 LogMemory::UNKNOWN_STEP_ID, *this);
789 }
790 }
791
792 // NOTE(mrry): The default allocator for a Tensor (when none is specified) is
793 // the default CPU allocator for NUMA zone 0. Accessing that currently involves
794 // acquiring a lock, which guards initialization of the per-NUMA zone
795 // allocators, and becomes highly contended.
796 //
797 // Note also that it would be better if all Tensor allocations required the user
798 // to specify an allocator, for purposes of accounting, etc. However, the
799 // default allocator is widely used throughout the codebase and in client code.
800 static Allocator* get_default_cpu_allocator() {
801 static Allocator* default_cpu_allocator =
802 cpu_allocator(port::kNUMANoAffinity);
803 return default_cpu_allocator;
804 }
805
806 Tensor::Tensor(DataType type, const TensorShape& shape)
807 : Tensor(get_default_cpu_allocator(), type, shape) {}
808
809 bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
810 size_t* out_bytes) const {
811 // `this->FillAllocationDescription()` never sets allocated bytes information,
812 // so we can short-circuit the construction of an `AllocationDescription`.
813 return false;
814 }
815
816 void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
817 AllocationDescription* proto) const {
818 proto->set_requested_bytes(size());
819 proto->set_allocator_name("HostScalarTensorBuffer");
820 proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
821 }
822
823 template <typename T>
824 class SubBuffer : public TensorBuffer {
825 public:
826 // This buffer is an alias to buf[delta, delta + n).
827 SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
828 : TensorBuffer(buf->base<T>() + delta),
829 root_(buf->root_buffer()),
830 elem_(n) {
831 // Sanity check. The caller should ensure the sub buffer is valid.
832 CHECK_LE(root_->base<T>(), this->base<T>());
833 T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
834 CHECK_LE(this->base<T>(), root_limit);
835 CHECK_LE(this->base<T>() + n, root_limit);
836 // Hold a ref of the underlying root buffer.
837 // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
838 root_->Ref();
839 }
840
841 size_t size() const override { return sizeof(T) * elem_; }
842 TensorBuffer* root_buffer() override { return root_; }
843 bool GetAllocatedBytes(size_t* out_bytes) const override {
844 return root_->GetAllocatedBytes(out_bytes);
845 }
846 void FillAllocationDescription(AllocationDescription* proto) const override {
847 root_->FillAllocationDescription(proto);
848 }
849
850 private:
851 TensorBuffer* root_;
852 int64 elem_;
853
854 ~SubBuffer() override { root_->Unref(); }
855
856 TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
857 };
858
859 Tensor Tensor::Slice(int64 start, int64 limit) const {
860 CHECK_GE(dims(), 1);
861 CHECK_LE(0, start);
862 CHECK_LE(start, limit);
863 int64 dim0_size = shape_.dim_size(0);
864 CHECK_LE(limit, dim0_size);
865 if ((start == 0) && (limit == dim0_size)) {
866 return *this;
867 }
868 Tensor ret;
869 ret.shape_ = shape_;
870 ret.set_dtype(dtype());
871 ret.buf_ = nullptr;
872 if (dim0_size > 0) {
873 const int64 elems_per_dim0 = NumElements() / dim0_size;
874 const int64 delta = start * elems_per_dim0;
875 dim0_size = limit - start;
876 ret.shape_.set_dim(0, dim0_size);
877 const int64 num_elems = dim0_size * elems_per_dim0;
878 if (buf_) {
879 DataType dt = dtype();
880 CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
881 }
882 }
883 return ret;
884 }
885
886 Tensor Tensor::SubSlice(int64 index) const {
887 CHECK_GE(dims(), 1); // Crash ok.
888 CHECK_LE(0, index); // Crash ok.
889 int64 dim0_size = shape_.dim_size(0);
890 CHECK_LE(index, dim0_size); // Crash ok.
891 Tensor ret;
892 ret.shape_ = shape_;
893 ret.shape_.RemoveDim(0);
894 ret.set_dtype(dtype());
895 ret.buf_ = nullptr;
896 if (dim0_size > 0) {
897 const int64 elems_per_dim0 = NumElements() / dim0_size;
898 const int64 delta = index * elems_per_dim0;
899 const int64 num_elems = elems_per_dim0;
900 if (buf_) {
901 DataType dt = dtype();
902 CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
903 }
904 }
905 return ret;
906 }
907
908 bool Tensor::FromProto(const TensorProto& proto) {
909 return FromProto(get_default_cpu_allocator(), proto);
910 }
911
912 bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
913 CHECK_NOTNULL(a);
914 TensorBuffer* p = nullptr;
915 if (!TensorShape::IsValid(proto.tensor_shape())) return false;
916 if (proto.dtype() == DT_INVALID) return false;
917 TensorShape shape(proto.tensor_shape());
918 const int64 N = shape.num_elements();
919 if (N > 0 && proto.dtype()) {
920 bool dtype_error = false;
921 if (!proto.tensor_content().empty()) {
922 const auto& content = proto.tensor_content();
923 CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
924 dtype_error = true, dtype_error = true);
925 } else {
926 CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
927 dtype_error = true, dtype_error = true);
928 }
929 if (dtype_error || p == nullptr) return false;
930 }
931 shape_ = shape;
932 set_dtype(proto.dtype());
933 UnrefIfNonNull(buf_);
934 buf_ = p;
935 // TODO(misard) add tracking of which kernels and steps are calling
936 // FromProto.
937 if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
938 LogMemory::RecordTensorAllocation("Unknown (from Proto)",
939 LogMemory::UNKNOWN_STEP_ID, *this);
940 }
941 return true;
942 }
943
944 void Tensor::AsProtoField(TensorProto* proto) const {
945 proto->Clear();
946 shape_.AsProto(proto->mutable_tensor_shape());
947 proto->set_dtype(dtype());
948 if (buf_) {
949 CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
950 }
951 }
952
953 void Tensor::AsProtoTensorContent(TensorProto* proto) const {
954 proto->Clear();
955 proto->set_dtype(dtype());
956 shape_.AsProto(proto->mutable_tensor_shape());
957 if (buf_) {
958 CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
959 proto->mutable_tensor_content()));
960 }
961 }
962
963 size_t Tensor::TotalBytes() const {
964 if (shape_.num_elements() == 0) return 0;
965 CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
966 CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
967 return 0; // Makes compiler happy.
968 }
969
970 size_t Tensor::AllocatedBytes() const {
971 if (buf_) {
972 size_t ret;
973 if (buf_->GetAllocatedBytes(&ret)) {
974 return ret;
975 }
976 }
977 return TotalBytes();
978 }
979
980 bool Tensor::CanUseDMA() const {
981 CASES(dtype(), return is_simple_type<T>::value);
982 return false; // Makes compiler happy.
983 }
984
985 #undef CASES
986 #undef CASE
987
988 namespace {
989
990 // StrCat and StrAppend don't support Eigen::half directly at the moment, and
991 // we would like to keep them compatible with their absl counterparts, for ease
992 // of migration. We could rely on errors::internal::PrepareForStrCat() but the
993 // logic is so simple we can just replicate it here, where it is close to its
994 // usage and easy to change later. And there's the extra benefit of not
995 // accessing an 'internal' namespace.
996 inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
997 bool print_v2) {
998 return a;
999 }
1000 inline string PrintOneElement(const tstring& a, bool print_v2) {
1001 if (print_v2) {
1002 return "\"" + absl::Utf8SafeCEscape(a) + "\"";
1003 } else {
1004 return absl::Utf8SafeCEscape(a);
1005 }
1006 }
1007 inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
1008 return static_cast<float>(h);
1009 }
1010
1011 inline float PrintOneElement(bfloat16 f, bool print_v2) {
1012 return static_cast<float>(f);
1013 }
1014
1015 // Print from left dim to right dim recursively.
1016 template <typename T>
1017 void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
1018 int64 limit, int shape_size, const T* data, int64* data_index,
1019 string* result) {
1020 if (*data_index >= limit) return;
1021 int64 element_count = shape[dim_index];
1022 // We have reached the right-most dimension of the tensor.
1023 if (dim_index == shape_size - 1) {
1024 for (int64 i = 0; i < element_count; i++) {
1025 if (*data_index >= limit) {
1026 // If not enough elements have been printed, append "...".
1027 if (dim_index != 0) {
1028 strings::StrAppend(result, "...");
1029 }
1030 return;
1031 }
1032 if (i > 0) strings::StrAppend(result, " ");
1033 strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
1034 }
1035 return;
1036 }
1037 // Loop every element of one dim.
1038 for (int64 i = 0; i < element_count; i++) {
1039 bool flag = false;
1040 if (*data_index < limit) {
1041 strings::StrAppend(result, "[");
1042 flag = true;
1043 }
1044 // As for each element, print the sub-dim.
1045 PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
1046 result);
1047 if (*data_index < limit || flag) {
1048 strings::StrAppend(result, "]");
1049 flag = false;
1050 }
1051 }
1052 }
1053
1054 // Appends the spacing between elements for a given dim onto a result string
1055 void PrintDimSpacing(int dim_index, int num_dims, string* result) {
1056 if (dim_index == num_dims - 1) {
1057 strings::StrAppend(result, " ");
1058 return;
1059 }
1060 for (int j = 0; j < num_dims - dim_index - 1; j++) {
1061 strings::StrAppend(result, "\n");
1062 }
1063 for (int j = 0; j <= dim_index; j++) {
1064 strings::StrAppend(result, " ");
1065 }
1066 }
1067
1068 // Print from left dim to right dim recursively.
1069 template <typename T>
1070 void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
1071 int64 num_elts_at_ends, int num_dims, const T* data,
1072 int64 data_index, string* result) {
1073 // We have recursed beyond all the dimensions into a single element
1074 // of the tensor.
1075 if (dim_index == num_dims) {
1076 strings::StrAppend(result, PrintOneElement(data[data_index], true));
1077 return;
1078 }
1079
1080 strings::StrAppend(result, "[");
1081 int64 element_count = shape[dim_index];
1082 int64 start_of_end =
1083 std::max(num_elts_at_ends, element_count - num_elts_at_ends);
1084
1085 // Loop every element of one dim.
1086 int64 elements_per_iter = 1;
1087 for (int i = dim_index + 1; i < num_dims; i++) {
1088 elements_per_iter *= shape[i];
1089 }
1090 for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
1091 if (i > 0) {
1092 PrintDimSpacing(dim_index, num_dims, result);
1093 }
1094
1095 // As for each element, print the sub-dim.
1096 PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1097 data_index + elements_per_iter * i, result);
1098 }
1099 if (element_count > 2 * num_elts_at_ends) {
1100 PrintDimSpacing(dim_index, num_dims, result);
1101 strings::StrAppend(result, "...");
1102 }
1103 for (int64 i = start_of_end; i < element_count; i++) {
1104 // As for each element, print the sub-dim.
1105 PrintDimSpacing(dim_index, num_dims, result);
1106 PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1107 data_index + elements_per_iter * i, result);
1108 }
1109
1110 strings::StrAppend(result, "]");
1111 }
1112
1113 template <typename T>
1114 string SummarizeArray(int64 limit, int64 num_elts,
1115 const TensorShape& tensor_shape, const char* data,
1116 const bool print_v2) {
1117 string ret;
1118 const T* array = reinterpret_cast<const T*>(data);
1119
1120 const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
1121 if (shape.empty()) {
1122 for (int64 i = 0; i < limit; ++i) {
1123 if (i > 0) strings::StrAppend(&ret, " ");
1124 strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
1125 }
1126 if (num_elts > limit) strings::StrAppend(&ret, "...");
1127 return ret;
1128 }
1129 if (print_v2) {
1130 const int num_dims = tensor_shape.dims();
1131 PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
1132 } else {
1133 int64 data_index = 0;
1134 const int shape_size = tensor_shape.dims();
1135 PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
1136
1137 if (num_elts > limit) strings::StrAppend(&ret, "...");
1138 }
1139
1140 return ret;
1141 }
1142 } // namespace
1143
1144 string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
1145 const int64 num_elts = NumElements();
1146 if (max_entries < 0) {
1147 max_entries = num_elts;
1148 }
1149 size_t limit = std::min(max_entries, num_elts);
1150 if ((limit > 0) && (buf_ == nullptr)) {
1151 return strings::StrCat("uninitialized Tensor of ", num_elts,
1152 " elements of type ", dtype());
1153 }
1154 const char* data = limit > 0 ? tensor_data().data() : nullptr;
1155 switch (dtype()) {
1156 case DT_BFLOAT16:
1157 return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
1158 break;
1159 case DT_HALF:
1160 return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
1161 print_v2);
1162 break;
1163 case DT_FLOAT:
1164 return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
1165 break;
1166 case DT_DOUBLE:
1167 return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
1168 break;
1169 case DT_UINT32:
1170 return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
1171 break;
1172 case DT_INT32:
1173 return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
1174 break;
1175 case DT_UINT8:
1176 case DT_QUINT8:
1177 return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
1178 break;
1179 case DT_UINT16:
1180 case DT_QUINT16:
1181 return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
1182 break;
1183 case DT_INT16:
1184 case DT_QINT16:
1185 return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
1186 break;
1187 case DT_INT8:
1188 case DT_QINT8:
1189 return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
1190 break;
1191 case DT_UINT64:
1192 return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
1193 break;
1194 case DT_INT64:
1195 return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
1196 break;
1197 case DT_BOOL:
1198 // TODO(tucker): Is it better to emit "True False..."? This
1199 // will emit "1 0..." which is more compact.
1200 return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
1201 break;
1202 case DT_STRING:
1203 return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
1204 break;
1205 default: {
1206 // All irregular cases
1207 string ret;
1208 if (print_v2) {
1209 strings::StrAppend(&ret, "[");
1210 }
1211 // TODO(irving): Don't call flat every time around this
1212 // loop.
1213 for (size_t i = 0; i < limit; ++i) {
1214 if (i > 0) strings::StrAppend(&ret, " ");
1215 switch (dtype()) {
1216 case DT_VARIANT: {
1217 const Variant& v = flat<Variant>()(i);
1218 strings::StrAppend(&ret, v.DebugString());
1219 } break;
1220 default:
1221 // TODO(zhifengc, josh11b): Pretty-print other types (bool,
1222 // complex64, quantized).
1223 strings::StrAppend(&ret, "?");
1224 }
1225 }
1226 if (max_entries < num_elts) strings::StrAppend(&ret, "...");
1227 if (print_v2) {
1228 strings::StrAppend(&ret, "]");
1229 }
1230 return ret;
1231 }
1232 }
1233 }
1234
1235 StringPiece Tensor::tensor_data() const {
1236 if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors
1237 return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
1238 }
1239
1240 void* Tensor::data() const {
1241 if (buf_ == nullptr) return nullptr; // Don't die for empty tensors
1242 return static_cast<void*>(buf_->data());
1243 }
1244
1245 bool Tensor::SharesBufferWith(const Tensor& b) const {
1246 return buf_ != nullptr && b.buf_ != nullptr &&
1247 buf_->root_buffer() == b.buf_->root_buffer();
1248 }
1249
1250 string Tensor::DebugString(int num_values) const {
1251 return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1252 " shape: ", shape().DebugString(),
1253 " values: ", SummarizeValue(num_values), ">");
1254 }
1255
1256 string Tensor::DeviceSafeDebugString() const {
1257 return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1258 " shape: ", shape().DebugString(), ">");
1259 }
1260
1261 void Tensor::FillDescription(TensorDescription* description) const {
1262 description->set_dtype(dtype());
1263 shape().AsProto(description->mutable_shape());
1264 if (buf_ != nullptr && buf_->data() != nullptr) {
1265 buf_->FillAllocationDescription(
1266 description->mutable_allocation_description());
1267 }
1268 }
1269
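// Collapses the leading dimensions of "orig" into the first output dimension,
// keeping the trailing dimensions intact (missing leading dims count as 1).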
1270 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
1271 gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1272 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1273 int64 offset = orig.size() - num_out_dims;
1274 for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
1275 const int64 in_dim = out_dim + offset;
1276 out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
1277 }
1278 for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
1279 out_dims[0] *= orig[in_dim];
1280 }
1281 return out_dims;
1282 }
1283
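// Collapses the trailing dimensions of "orig" into the last output dimension,
// keeping the leading dimensions intact (missing dims count as 1).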
1284 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
1285 gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1286 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1287 for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
1288 out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
1289 }
1290 for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
1291 out_dims[num_out_dims - 1] *= orig[in_dim];
1292 }
1293 return out_dims;
1294 }
1295
1296 } // namespace tensorflow
1297