1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implementation notes:
17 //
18 // Tensor.cc uses a few templated classes and structs to facilitate
19 // implementation of the Tensor class.
20 //
21 // * Buffer<T>: provides the implementation for a typed array T[n].
22 // The array is allocated by the given allocator. It runs T's
23 // default constructors and destructors when T is not a simple type
24 //   (e.g., string), and skips them otherwise.
25 //
26 // * Helper<T>: provides various routines given type T. The routines
27 //   include running the constructor and destructor of T[], encoding
28 //   and decoding T[] into/from a Cord, etc.
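//
// For instance (an informal sketch of how the pieces below fit together, not
// additional machinery): a DT_FLOAT Tensor with 4 elements owns a
// Buffer<float> wrapping a float[4] allocated by its Allocator, and
// AsProtoTensorContent() serializes it via Helper<float>::Encode, which
// copies those 16 bytes into TensorProto::tensor_content.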
29
30 #include "tensorflow/core/framework/tensor.h"
31
32 #include "absl/strings/escaping.h"
33 #include "tensorflow/core/framework/allocation_description.pb.h"
34 #include "tensorflow/core/framework/log_memory.h"
35 #include "tensorflow/core/framework/resource_handle.pb.h"
36 #include "tensorflow/core/framework/tensor.pb.h"
37 #include "tensorflow/core/framework/tensor_description.pb.h"
38 #include "tensorflow/core/framework/type_traits.h"
39 #include "tensorflow/core/framework/typed_allocator.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/framework/variant.h"
42 #include "tensorflow/core/framework/variant_encode_decode.h"
43 #include "tensorflow/core/framework/variant_op_registry.h"
44 #include "tensorflow/core/framework/variant_tensor_data.h"
45 #include "tensorflow/core/lib/core/coding.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/gtl/inlined_vector.h"
49 #include "tensorflow/core/lib/strings/str_util.h"
50 #include "tensorflow/core/lib/strings/strcat.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/tensor_coding.h"
55 #include "tensorflow/core/platform/types.h"
56
57 namespace tensorflow {
58
59 // Allow Tensors to be stored inside Variants with automatic
60 // encoding/decoding when those Variants are themselves being decoded
61 // in a Tensor's FromProto.
62 //
63 // NOTE(mrry): The corresponding "copy function" registrations can be found in
64 // ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
65 // code).
66 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
67
68 bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
69 AllocationDescription allocation_description;
70 FillAllocationDescription(&allocation_description);
71 if (allocation_description.allocated_bytes() > 0) {
72 *out_bytes = allocation_description.allocated_bytes();
73 return true;
74 } else {
75 return false;
76 }
77 }
78
79 namespace {
80
81 // An un-templated base class for Buffer.
82 class BufferBase : public TensorBuffer {
83 public:
84 explicit BufferBase(Allocator* alloc, void* data_ptr)
85 : TensorBuffer(data_ptr), alloc_(alloc) {}
86
87 TensorBuffer* root_buffer() override { return this; }
88
89 bool GetAllocatedBytes(size_t* out_bytes) const override {
90 if (alloc_->TracksAllocationSizes()) {
91 *out_bytes = alloc_->AllocatedSize(data());
92 return *out_bytes > 0;
93 } else {
94 return false;
95 }
96 }
97
98 void FillAllocationDescription(AllocationDescription* proto) const override {
99 void* data_ptr = data();
100 int64 rb = size();
101 proto->set_requested_bytes(rb);
102 proto->set_allocator_name(alloc_->Name());
103 proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
104 if (alloc_->TracksAllocationSizes()) {
105 int64 ab = alloc_->AllocatedSize(data_ptr);
106 proto->set_allocated_bytes(ab);
107 int64 id = alloc_->AllocationId(data_ptr);
108 if (id > 0) {
109 proto->set_allocation_id(id);
110 }
111 if (RefCountIsOne()) {
112 proto->set_has_single_reference(true);
113 }
114 }
115 }
116
117 protected:
118 void RecordDeallocation() {
119 LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
120 alloc_->Name());
121 }
122
123 Allocator* const alloc_;
124 };
125
126 // Typed ref-counted buffer: T[n].
127 template <typename T>
128 class Buffer : public BufferBase {
129 public:
130 Buffer(Allocator* a, int64 n);
131 Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);
132
133 size_t size() const override { return sizeof(T) * elem_; }
134
135 private:
136 int64 elem_;
137
138 ~Buffer() override;
139
140 TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
141 };
142
143 void LogUnexpectedSize(int64 actual, int64 expected) {
144 LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
145 }
146
147 bool MemoryLoggingEnabled() {
148 static bool memory_logging_enabled = LogMemory::IsEnabled();
149 return memory_logging_enabled;
150 }
151
152 // A set of helper functions depending on T.
153 template <typename T>
154 struct Helper {
155 // By default, we assume T is a simple type (float, int32, etc.)
156 static_assert(is_simple_type<T>::value, "T is not a simple type.");
157 typedef protobuf::RepeatedField<T> RepeatedFieldType;
158
159 // Encoder of simple type T to a string. We do a copy.
160 template <typename Destination>
161 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
162 DCHECK_EQ(in->size(), sizeof(T) * n);
163 port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
164 out);
165 }
166
167 // Decoder of simple type T. Copy the bytes from "in" into the
168 // tensor buffer.
169 template <typename Source>
170 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
171 if (in.size() != sizeof(T) * n) {
172 LogUnexpectedSize(in.size(), sizeof(T) * n);
173 return nullptr;
174 }
175 Buffer<T>* buf = new Buffer<T>(a, n);
176 char* data = buf->template base<char>();
177 if (data == nullptr) {
178 buf->Unref();
179 return nullptr;
180 }
181 port::CopyToArray(in, data);
182 return buf;
183 }
184
185 // Memory usage.
186 static int64 TotalBytes(TensorBuffer* in, int64 n) {
187 DCHECK_EQ(in->size(), sizeof(T) * n);
188 return in->size();
189 }
190 };
191
192 // Helper specialization for string (the only non-simple type we
193 // support).
194 template <>
195 struct Helper<tstring> {
196 // Proto message uses RepeatedFieldType to hold repeated T.
197 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
198
199 // Encodes "n" elements of type string stored in "in" into Cord
200 // "out", which is usually the TensorProto::tensor_content.
201 template <typename Destination>
202 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
203 port::EncodeStringList(in->base<const tstring>(), n, out);
204 }
205
206 // Decodes "n" elements of type string from "in" and constructs a
207 // buffer out of it. Returns nullptr if the decoding fails. "in" is
208 // usually the TensorProto::tensor_content.
209 template <typename Source>
210 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
211 Buffer<tstring>* buf = new Buffer<tstring>(a, n);
212 tstring* strings = buf->template base<tstring>();
213 if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
214 buf->Unref();
215 return nullptr;
216 }
217 return buf;
218 }
219
220 // Returns the estimated memory usage of "n" elements of type T
221 // stored in buffer "in".
222 static int64 TotalBytes(TensorBuffer* in, int n) {
223 int64 tot = in->size();
224 DCHECK_EQ(tot, sizeof(tstring) * n);
225 const tstring* p = in->base<const tstring>();
226 for (int i = 0; i < n; ++i, ++p) tot += p->size();
227 return tot;
228 }
229 };
230
231 template <>
232 struct Helper<ResourceHandle> {
233 // Proto message uses RepeatedFieldType to hold repeated T.
234 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
235
236 // Encodes "n" elements of type ResourceHandle stored in "in" into destination
237 // "out", which is usually the TensorProto::tensor_content.
238 template <typename Destination>
239 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
240 EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
241 port::NewStringListEncoder(out));
242 }
243
244 // Decodes "n" elements of type string from "in" and constructs a
245 // buffer out of it. Returns nullptr if the decoding fails. "in" is
246 // usually the TensorProto::tensor_content.
247 template <typename Source>
248 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
249 auto* buf = new Buffer<ResourceHandle>(a, n);
250 ResourceHandle* ps = buf->template base<ResourceHandle>();
251 if (ps == nullptr ||
252 !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
253 buf->Unref();
254 return nullptr;
255 }
256 return buf;
257 }
258
259 // Returns the estimated memory usage of "n" elements of type T
260 // stored in buffer "in".
261 static int64 TotalBytes(TensorBuffer* in, int n) {
262 return n * sizeof(ResourceHandle);
263 }
264 };
265
266 template <>
267 struct Helper<Variant> {
268 // Encodes "n" elements of type Variant stored in "in" into destination
269 // "out", which is usually the TensorProto::tensor_content.
270 template <typename Destination>
271 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
272 EncodeVariantList(in->base<const Variant>(), n,
273 port::NewStringListEncoder(out));
274 }
275
276 // Decodes "n" elements of type Variant from "in" and constructs a
277 // buffer out of it. Returns nullptr if the decoding fails. "in" is
278 // usually the TensorProto::tensor_content.
279 template <typename Source>
280 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
281 auto* buf = new Buffer<Variant>(a, n);
282 Variant* ps = buf->template base<Variant>();
283 if (ps == nullptr ||
284 !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
285 buf->Unref();
286 return nullptr;
287 }
288 return buf;
289 }
290
291 // Returns the estimated memory usage of "n" elements of type T
292 // stored in buffer "in".
293 static int64 TotalBytes(TensorBuffer* in, int n) {
294 return n * sizeof(Variant);
295 }
296 };
297
298 template <typename T>
299 struct ProtoHelper {};
300
301 // For a C++ type "T" (float, double, int32, etc.), the repeated field
302 // "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
303 // int32, string, etc) in the TensorProto is used for serializing the
304 // tensor of type "T".
305 #define PROTO_TRAITS(T, F, N) \
306 template <> \
307 struct ProtoHelper<T> { \
308 typedef Helper<F>::RepeatedFieldType FieldType; \
309 static FieldType::const_iterator Begin(const TensorProto& proto) { \
310 return proto.N##_val().begin(); \
311 } \
312 static size_t NumElements(const TensorProto& proto) { \
313 return proto.N##_val().size(); \
314 } \
315 static void Fill(const T* data, size_t n, TensorProto* proto) { \
316 typename ProtoHelper<T>::FieldType copy(data, data + n); \
317 proto->mutable_##N##_val()->Swap(&copy); \
318 } \
319 };
320 PROTO_TRAITS(float, float, float);
321 PROTO_TRAITS(double, double, double);
322 PROTO_TRAITS(int32, int32, int);
323 PROTO_TRAITS(uint8, int32, int);
324 PROTO_TRAITS(uint16, int32, int);
325 PROTO_TRAITS(uint32, uint32, uint32);
326 PROTO_TRAITS(int16, int32, int);
327 PROTO_TRAITS(int8, int32, int);
328 PROTO_TRAITS(bool, bool, bool);
329 PROTO_TRAITS(tstring, tstring, string);
330 PROTO_TRAITS(qint8, int32, int);
331 PROTO_TRAITS(quint8, int32, int);
332 PROTO_TRAITS(qint16, int32, int);
333 PROTO_TRAITS(quint16, int32, int);
334 #undef PROTO_TRAITS
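// As an illustration only (this is what the macro above generates, not extra
// code), PROTO_TRAITS(uint8, int32, int) expands to roughly:
//
//   template <>
//   struct ProtoHelper<uint8> {
//     typedef Helper<int32>::RepeatedFieldType FieldType;
//     static FieldType::const_iterator Begin(const TensorProto& proto) {
//       return proto.int_val().begin();
//     }
//     static size_t NumElements(const TensorProto& proto) {
//       return proto.int_val().size();
//     }
//     static void Fill(const uint8* data, size_t n, TensorProto* proto) {
//       typename ProtoHelper<uint8>::FieldType copy(data, data + n);
//       proto->mutable_int_val()->Swap(&copy);
//     }
//   };
//
// i.e., uint8 tensors are serialized through the int32-typed int_val field.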
335
336 template <>
337 struct ProtoHelper<int64> {
338 static const int64* Begin(const TensorProto& proto) {
339 return reinterpret_cast<const int64*>(proto.int64_val().begin());
340 }
341 static size_t NumElements(const TensorProto& proto) {
342 return proto.int64_val().size();
343 }
344 static void Fill(const int64* data, size_t n, TensorProto* proto) {
345 protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
346 proto->mutable_int64_val()->Swap(&copy);
347 }
348 };
349
350 template <>
351 struct ProtoHelper<uint64> {
352 static const uint64* Begin(const TensorProto& proto) {
353 return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
354 }
355 static size_t NumElements(const TensorProto& proto) {
356 return proto.uint64_val().size();
357 }
358 static void Fill(const uint64* data, size_t n, TensorProto* proto) {
359 protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
360 proto->mutable_uint64_val()->Swap(&copy);
361 }
362 };
363
364 template <>
365 struct ProtoHelper<ResourceHandle> {
366 static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
367 const TensorProto& proto) {
368 return proto.resource_handle_val().begin();
369 }
370 static size_t NumElements(const TensorProto& proto) {
371 return proto.resource_handle_val().size();
372 }
373 static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
374 auto* handles = proto->mutable_resource_handle_val();
375 handles->Clear();
376 for (size_t i = 0; i < n; i++) {
377 data[i].AsProto(handles->Add());
378 }
379 }
380 };
381
382 template <>
383 struct ProtoHelper<Variant> {
384 static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
385 Begin(const TensorProto& proto) {
386 return proto.variant_val().begin();
387 }
388 static size_t NumElements(const TensorProto& proto) {
389 return proto.variant_val().size();
390 }
391 static void Fill(const Variant* data, size_t n, TensorProto* proto) {
392 auto* variant_values = proto->mutable_variant_val();
393 variant_values->Clear();
394 for (size_t i = 0; i < n; ++i) {
395 VariantTensorData tmp;
396 data[i].Encode(&tmp);
397 tmp.ToProto(variant_values->Add());
398 }
399 }
400 };
401
402 template <>
403 struct ProtoHelper<complex64> {
404 typedef Helper<float>::RepeatedFieldType FieldType;
405 static const complex64* Begin(const TensorProto& proto) {
406 return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
407 }
408 static size_t NumElements(const TensorProto& proto) {
409 return proto.scomplex_val().size() / 2;
410 }
411 static void Fill(const complex64* data, size_t n, TensorProto* proto) {
412 const float* p = reinterpret_cast<const float*>(data);
413 FieldType copy(p, p + n * 2);
414 proto->mutable_scomplex_val()->Swap(&copy);
415 }
416 };
417
418 template <>
419 struct ProtoHelper<complex128> {
420 typedef Helper<double>::RepeatedFieldType FieldType;
421 static const complex128* Begin(const TensorProto& proto) {
422 return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
423 }
424 static size_t NumElements(const TensorProto& proto) {
425 return proto.dcomplex_val().size() / 2;
426 }
427 static void Fill(const complex128* data, size_t n, TensorProto* proto) {
428 const double* p = reinterpret_cast<const double*>(data);
429 FieldType copy(p, p + n * 2);
430 proto->mutable_dcomplex_val()->Swap(&copy);
431 }
432 };
433
434 template <>
435 struct ProtoHelper<qint32> {
436 typedef Helper<int32>::RepeatedFieldType FieldType;
437 static const qint32* Begin(const TensorProto& proto) {
438 return reinterpret_cast<const qint32*>(proto.int_val().data());
439 }
440 static size_t NumElements(const TensorProto& proto) {
441 return proto.int_val().size();
442 }
443 static void Fill(const qint32* data, size_t n, TensorProto* proto) {
444 const int32* p = reinterpret_cast<const int32*>(data);
445 FieldType copy(p, p + n);
446 proto->mutable_int_val()->Swap(&copy);
447 }
448 };
449
450 template <>
451 struct ProtoHelper<bfloat16> {
452 static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
453 proto->mutable_half_val()->Reserve(n);
454 for (size_t i = 0; i < n; ++i) {
455 proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
456 }
457 }
458 };
459
460 template <>
461 struct ProtoHelper<Eigen::half> {
462 static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
463 proto->mutable_half_val()->Reserve(n);
464 for (size_t i = 0; i < n; ++i) {
465 proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
466 }
467 }
468 };
469
470 template <typename T>
471 Buffer<T>::Buffer(Allocator* a, int64 n)
472 : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
473 elem_(n) {}
474
475 template <typename T>
476 Buffer<T>::Buffer(Allocator* a, int64 n,
477 const AllocationAttributes& allocation_attr)
478 : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
479 elem_(n) {}
480
481 template <typename T>
482 Buffer<T>::~Buffer() {
483 if (data()) {
484 if (MemoryLoggingEnabled()) {
485 RecordDeallocation();
486 }
487 TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
488 }
489 }
490
491 // Allocates a T[n] buffer. Fills in the buffer with repeated values
492 // in "in". If "in" has less values than "n", fills the rest of T[n]
493 // with the last value. If "in" has no values, fills T[n] with the
494 // default value for T.
495 //
496 // This routine uses the typed fields (float_val, etc.) in the
497 // tensor proto rather than the untyped binary representation
498 // (tensor_content). It is used when we expect the TensorProto was
499 // produced by a client program that may not know how to encode a tensor
500 // in the compact binary representation.
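//
// For example (behavior as described above, values illustrative): decoding a
// proto whose float_val holds {1, 2} into a 5-element DT_FLOAT tensor yields
// {1, 2, 2, 2, 2}; an empty float_val yields {0, 0, 0, 0, 0}.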
501 template <typename T>
502 TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
503 CHECK_GT(n, 0);
504 Buffer<T>* buf = new Buffer<T>(a, n);
505 T* data = buf->template base<T>();
506 if (data == nullptr) {
507 buf->Unref();
508 return nullptr;
509 }
510
511 const int64 in_n = ProtoHelper<T>::NumElements(in);
512 if (in_n <= 0) {
513 std::fill_n(data, n, T());
514 } else {
515 auto begin = ProtoHelper<T>::Begin(in);
516 if (n <= in_n) {
517 std::copy_n(begin, n, data);
518 } else {
519 std::copy_n(begin, in_n, data);
520 if (std::is_trivially_copyable<T>::value) {
521 const T last = *(data + in_n - 1);
522 std::fill_n(data + in_n, n - in_n, last);
523 } else {
524 const T& last = *(data + in_n - 1);
525 std::fill_n(data + in_n, n - in_n, last);
526 }
527 }
528 }
529
530 return buf;
531 }
532
533 template <>
534 TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
535 int64 n) {
536 CHECK_GT(n, 0);
537 Buffer<Variant>* buf = new Buffer<Variant>(a, n);
538 Variant* data = buf->template base<Variant>();
539 if (data == nullptr) {
540 buf->Unref();
541 return nullptr;
542 }
543 const int64 in_n = ProtoHelper<Variant>::NumElements(in);
544 if (in_n <= 0) {
545 std::fill_n(data, n, Variant());
546 } else {
547 // If the tensor shape says we have n < in_n elements in the output tensor,
548 // then only decode the first n of the in_n elements of "in". In all other
549 // cases, decode all in_n elements of "in" and set the remaining elements up
550 // to n to the default Variant() value.
551 const int64 real_n = n < in_n ? n : in_n;
552 for (int64 i = 0; i < real_n; ++i) {
553 data[i] = in.variant_val(i);
554 if (!DecodeUnaryVariant(&data[i])) {
555 LOG(ERROR) << "Could not decode variant with type_name: \""
556 << data[i].TypeName()
557 << "\". Perhaps you forgot to register a "
558 "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
559 buf->Unref();
560 return nullptr;
561 }
562 }
563 for (int64 i = in_n; i < n; ++i) {
564 data[i] = Variant();
565 }
566 }
567 return buf;
568 }
569
570 // fp16 and bfloat16 are opaque to the protobuf, so we deserialize them
571 // identically to uint16 but with data stored in half_val instead of int_val
572 // (i.e., we don't use ProtoHelper<uint16>).
573 template <>
574 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
575 int64 n) {
576 CHECK_GT(n, 0);
577 Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
578 uint16* data = buf->template base<uint16>();
579 if (data == nullptr) {
580 buf->Unref();
581 return nullptr;
582 }
583 const int64 in_n = in.half_val().size();
584 auto begin = in.half_val().begin();
585 if (n <= in_n) {
586 std::copy_n(begin, n, data);
587 } else if (in_n > 0) {
588 std::copy_n(begin, in_n, data);
589 const uint16 last = *(data + in_n - 1);
590 std::fill_n(data + in_n, n - in_n, last);
591 } else {
592 std::fill_n(data, n, 0);
593 }
594 return buf;
595 }
596
597 template <>
598 TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
599 int64 n) {
600 CHECK_GT(n, 0);
601 Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
602 uint16* data = buf->template base<uint16>();
603 if (data == nullptr) {
604 buf->Unref();
605 return nullptr;
606 }
607 const int64 in_n = in.half_val().size();
608 auto begin = in.half_val().begin();
609 if (n <= in_n) {
610 std::copy_n(begin, n, data);
611 } else if (in_n > 0) {
612 std::copy_n(begin, in_n, data);
613 const uint16 last = *(data + in_n - 1);
614 std::fill_n(data + in_n, n - in_n, last);
615 } else {
616 std::fill_n(data, n, 0);
617 }
618 return buf;
619 }
620
621 // Copies T[n] stored in the buffer "in" into the repeated field in
622 // "out" corresponding to type T.
623 template <typename T>
624 void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
625 const T* data = in.base<const T>();
626 // NOTE: T may not be the same as
627 // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
628 // ProtoHelper<T>::FieldType::value_type==int32. If performance is
629 // critical, we can specialize T=float and do memcpy directly.
630 ProtoHelper<T>::Fill(data, n, out);
631 }
632
633 void RefIfNonNull(core::RefCounted* buf) {
634 if (buf) buf->Ref();
635 }
636
637 void UnrefIfNonNull(core::RefCounted* buf) {
638 if (buf) buf->Unref();
639 }
640
641 } // end namespace
642
643 Tensor::Tensor() : Tensor(DT_FLOAT) {}
644
645 Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}
646
647 Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
648 : shape_(shape), buf_(buf) {
649 set_dtype(type);
650 RefIfNonNull(buf);
651 }
652
653 bool Tensor::IsInitialized() const {
654 return (buf_ != nullptr && buf_->data() != nullptr) ||
655 shape_.num_elements() == 0;
656 }
657
658 void Tensor::CheckType(DataType expected_dtype) const {
659 CHECK_EQ(dtype(), expected_dtype)
660 << " " << DataTypeString(expected_dtype) << " expected, got "
661 << DataTypeString(dtype());
662 }
663
664 void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
665 CHECK_EQ(dtype(), expected_dtype)
666 << " " << DataTypeString(expected_dtype) << " expected, got "
667 << DataTypeString(dtype());
668 CHECK(IsAligned()) << "ptr = " << base<void>();
669 }
670
671 void Tensor::CheckIsAlignedAndSingleElement() const {
672 CHECK(IsAligned()) << "Aligned and single element";
673 CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
674 }
675
676 Tensor::~Tensor() { UnrefIfNonNull(buf_); }
677
678 void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
679 CHECK_EQ(shape.num_elements(), other.NumElements());
680 // Data type will be overwritten if this == &other, since dtype is part of
681 // shape.
682 DataType other_dtype = other.dtype();
683 shape_ = shape;
684 set_dtype(other_dtype);
685 if (buf_ != other.buf_) {
686 UnrefIfNonNull(buf_);
687 buf_ = other.buf_;
688 RefIfNonNull(buf_);
689 }
690 }
691
692 Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
693 const TensorShape& shape) {
694 int in_size = DataTypeSize(other.dtype());
695 int out_size = DataTypeSize(dtype);
696 if (in_size == 0) {
697 return errors::InvalidArgument("other tensor has zero-sized data type");
698 }
699 if (out_size == 0) {
700 return errors::InvalidArgument("specified output type is zero-sized");
701 }
702 if (shape.num_elements() * out_size !=
703 other.shape().num_elements() * in_size) {
704 return errors::InvalidArgument(
705 "input and output shapes/data type sizes are not compatible");
706 }
707 shape_ = shape;
708 shape_.set_data_type(dtype);
709 if (buf_ != other.buf_) {
710 UnrefIfNonNull(buf_);
711 buf_ = other.buf_;
712 RefIfNonNull(buf_);
713 }
714 return Status::OK();
715 }
716
717 // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
718 // For the latter case, we have to make sure that the refcount is
719 // one both for the SubBuffer _and_ the underlying TensorBuffer.
720 bool Tensor::RefCountIsOne() const {
721 return buf_ != nullptr && buf_->RefCountIsOne() &&
722 buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
723 }
724
725 // The macro CASES() expands to a switch statement conditioned on
726 // TYPE_ENUM. Each case expands the STMTS after a typedef for T.
727 #define SINGLE_ARG(...) __VA_ARGS__
728 #define CASE(TYPE, STMTS) \
729 case DataTypeToEnum<TYPE>::value: { \
730 typedef TYPE T; \
731 STMTS; \
732 break; \
733 }
734 #define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
735 switch (TYPE_ENUM) { \
736 CASE(float, SINGLE_ARG(STMTS)) \
737 CASE(double, SINGLE_ARG(STMTS)) \
738 CASE(int32, SINGLE_ARG(STMTS)) \
739 CASE(uint8, SINGLE_ARG(STMTS)) \
740 CASE(uint16, SINGLE_ARG(STMTS)) \
741 CASE(uint32, SINGLE_ARG(STMTS)) \
742 CASE(uint64, SINGLE_ARG(STMTS)) \
743 CASE(int16, SINGLE_ARG(STMTS)) \
744 CASE(int8, SINGLE_ARG(STMTS)) \
745 CASE(tstring, SINGLE_ARG(STMTS)) \
746 CASE(complex64, SINGLE_ARG(STMTS)) \
747 CASE(complex128, SINGLE_ARG(STMTS)) \
748 CASE(int64, SINGLE_ARG(STMTS)) \
749 CASE(bool, SINGLE_ARG(STMTS)) \
750 CASE(qint32, SINGLE_ARG(STMTS)) \
751 CASE(quint8, SINGLE_ARG(STMTS)) \
752 CASE(qint8, SINGLE_ARG(STMTS)) \
753 CASE(quint16, SINGLE_ARG(STMTS)) \
754 CASE(qint16, SINGLE_ARG(STMTS)) \
755 CASE(bfloat16, SINGLE_ARG(STMTS)) \
756 CASE(Eigen::half, SINGLE_ARG(STMTS)) \
757 CASE(ResourceHandle, SINGLE_ARG(STMTS)) \
758 CASE(Variant, SINGLE_ARG(STMTS)) \
759 case DT_INVALID: \
760 INVALID; \
761 break; \
762 default: \
763 DEFAULT; \
764 break; \
765 }
766
767 #define CASES(TYPE_ENUM, STMTS) \
768 CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
769 , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
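// As a reading aid (illustrative expansion, not additional code): a call such
// as CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems)) used
// further below becomes a switch on dt in which, e.g., the DT_FLOAT case is
//
//   case DataTypeToEnum<float>::value: {
//     typedef float T;
//     ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems);
//     break;
//   }
//
// with DT_INVALID and unknown enum values hitting the LOG(FATAL) branches.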
770
771 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
772 : shape_(shape), buf_(nullptr) {
773 set_dtype(type);
774 CHECK_NOTNULL(a);
775 if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
776 CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
777 }
778 if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
779 LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
780 *this);
781 }
782 }
783
784 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
785 const AllocationAttributes& allocation_attr)
786 : shape_(shape), buf_(nullptr) {
787 set_dtype(type);
788 CHECK_NOTNULL(a);
789 if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
790 CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
791 }
792 if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
793 buf_ != nullptr && buf_->data() != nullptr) {
794 LogMemory::RecordTensorAllocation("Unknown (with attributes)",
795 LogMemory::UNKNOWN_STEP_ID, *this);
796 }
797 }
798
799 // NOTE(mrry): The default allocator for a Tensor (when none is specified) is
800 // the default CPU allocator for NUMA zone 0. Accessing that currently involves
801 // acquiring a lock, which guards initialization of the per-NUMA zone
802 // allocators, and becomes highly contended.
803 //
804 // Note also that it would be better if all Tensor allocations required the user
805 // to specify an allocator, for purposes of accounting, etc. However, the
806 // default allocator is widely used throughout the codebase and in client code.
807 static Allocator* get_default_cpu_allocator() {
808 static Allocator* default_cpu_allocator =
809 cpu_allocator(port::kNUMANoAffinity);
810 return default_cpu_allocator;
811 }
812
813 Tensor::Tensor(DataType type, const TensorShape& shape)
814 : Tensor(get_default_cpu_allocator(), type, shape) {}
815
816 bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
817 size_t* out_bytes) const {
818 // `this->FillAllocationDescription()` never sets allocated bytes information,
819 // so we can short-circuit the construction of an `AllocationDescription`.
820 return false;
821 }
822
823 void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
824 AllocationDescription* proto) const {
825 proto->set_requested_bytes(size());
826 proto->set_allocator_name("HostScalarTensorBuffer");
827 proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
828 }
829
830 template <typename T>
831 class SubBuffer : public TensorBuffer {
832 public:
833 // This buffer is an alias to buf[delta, delta + n).
834 SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
835 : TensorBuffer(buf->base<T>() + delta),
836 root_(buf->root_buffer()),
837 elem_(n) {
838 // Sanity check. The caller should ensure the sub buffer is valid.
839 CHECK_LE(root_->base<T>(), this->base<T>());
840 T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
841 CHECK_LE(this->base<T>(), root_limit);
842 CHECK_LE(this->base<T>() + n, root_limit);
843 // Hold a ref of the underlying root buffer.
844 // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
845 root_->Ref();
846 }
847
848 size_t size() const override { return sizeof(T) * elem_; }
849 TensorBuffer* root_buffer() override { return root_; }
850 bool GetAllocatedBytes(size_t* out_bytes) const override {
851 return root_->GetAllocatedBytes(out_bytes);
852 }
853 void FillAllocationDescription(AllocationDescription* proto) const override {
854 root_->FillAllocationDescription(proto);
855 }
856
857 private:
858 TensorBuffer* root_;
859 int64 elem_;
860
861 ~SubBuffer() override { root_->Unref(); }
862
863 TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
864 };
865
866 Tensor Tensor::Slice(int64 start, int64 limit) const {
867 CHECK_GE(dims(), 1);
868 CHECK_LE(0, start);
869 CHECK_LE(start, limit);
870 int64 dim0_size = shape_.dim_size(0);
871 CHECK_LE(limit, dim0_size);
872 if ((start == 0) && (limit == dim0_size)) {
873 return *this;
874 }
875 Tensor ret;
876 ret.shape_ = shape_;
877 ret.set_dtype(dtype());
878 ret.buf_ = nullptr;
879 if (dim0_size > 0) {
880 const int64 elems_per_dim0 = NumElements() / dim0_size;
881 const int64 delta = start * elems_per_dim0;
882 dim0_size = limit - start;
883 ret.shape_.set_dim(0, dim0_size);
884 const int64 num_elems = dim0_size * elems_per_dim0;
885 if (buf_) {
886 DataType dt = dtype();
887 CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
888 }
889 }
890 return ret;
891 }
892
893 Tensor Tensor::SubSlice(int64 index) const {
894 CHECK_GE(dims(), 1); // Crash ok.
895 CHECK_LE(0, index); // Crash ok.
896 int64 dim0_size = shape_.dim_size(0);
897 CHECK_LE(index, dim0_size); // Crash ok.
898 Tensor ret;
899 ret.shape_ = shape_;
900 ret.shape_.RemoveDim(0);
901 ret.set_dtype(dtype());
902 ret.buf_ = nullptr;
903 if (dim0_size > 0) {
904 const int64 elems_per_dim0 = NumElements() / dim0_size;
905 const int64 delta = index * elems_per_dim0;
906 const int64 num_elems = elems_per_dim0;
907 if (buf_) {
908 DataType dt = dtype();
909 CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
910 }
911 }
912 return ret;
913 }
914
915 bool Tensor::FromProto(const TensorProto& proto) {
916 return FromProto(get_default_cpu_allocator(), proto);
917 }
918
919 bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
920 CHECK_NOTNULL(a);
921 TensorBuffer* p = nullptr;
922 if (!TensorShape::IsValid(proto.tensor_shape())) return false;
923 if (proto.dtype() == DT_INVALID) return false;
924 TensorShape shape(proto.tensor_shape());
925 const int64 N = shape.num_elements();
926 if (N > 0 && proto.dtype()) {
927 bool dtype_error = false;
928 if (!proto.tensor_content().empty()) {
929 const auto& content = proto.tensor_content();
930 CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
931 dtype_error = true, dtype_error = true);
932 } else {
933 CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
934 dtype_error = true, dtype_error = true);
935 }
936 if (dtype_error || p == nullptr) return false;
937 }
938 shape_ = shape;
939 set_dtype(proto.dtype());
940 UnrefIfNonNull(buf_);
941 buf_ = p;
942 // TODO(misard) add tracking of which kernels and steps are calling
943 // FromProto.
944 if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
945 LogMemory::RecordTensorAllocation("Unknown (from Proto)",
946 LogMemory::UNKNOWN_STEP_ID, *this);
947 }
948 return true;
949 }
950
951 void Tensor::AsProtoField(TensorProto* proto) const {
952 proto->Clear();
953 shape_.AsProto(proto->mutable_tensor_shape());
954 proto->set_dtype(dtype());
955 if (buf_) {
956 CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
957 }
958 }
959
960 void Tensor::AsProtoTensorContent(TensorProto* proto) const {
961 proto->Clear();
962 proto->set_dtype(dtype());
963 shape_.AsProto(proto->mutable_tensor_shape());
964 if (buf_) {
965 CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
966 proto->mutable_tensor_content()));
967 }
968 }
969
970 size_t Tensor::TotalBytes() const {
971 if (shape_.num_elements() == 0) return 0;
972 CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
973 CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
974 return 0; // Makes compiler happy.
975 }
976
977 size_t Tensor::AllocatedBytes() const {
978 if (buf_) {
979 size_t ret;
980 if (buf_->GetAllocatedBytes(&ret)) {
981 return ret;
982 }
983 }
984 return TotalBytes();
985 }
986
987 bool Tensor::CanUseDMA() const {
988 CASES(dtype(), return is_simple_type<T>::value);
989 return false; // Makes compiler happy.
990 }
991
992 #undef CASES
993 #undef CASE
994
995 namespace {
996
997 // StrCat and StrAppend don't support Eigen::half directly at the moment, and
998 // we would like to keep them compatible with their absl counterparts, for ease
999 // of migration. We could rely on errors::internal::PrepareForStrCat() but the
1000 // logic is so simple we can just replicate it here, where it is close to its
1001 // usage and easy to change later. And there's the extra benefit of not
1002 // accessing an 'internal' namespace.
1003 inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
1004 bool print_v2) {
1005 return a;
1006 }
1007 inline string PrintOneElement(const tstring& a, bool print_v2) {
1008 if (print_v2) {
1009 return "\"" + absl::CEscape(a) + "\"";
1010 } else {
1011 return absl::CEscape(a);
1012 }
1013 }
1014 inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
1015 return static_cast<float>(h);
1016 }
1017
1018 inline float PrintOneElement(bfloat16 f, bool print_v2) {
1019 return static_cast<float>(f);
1020 }
1021
1022 // Print from left dim to right dim recursively.
1023 template <typename T>
1024 void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
1025 int64 limit, int shape_size, const T* data, int64* data_index,
1026 string* result) {
1027 if (*data_index >= limit) return;
1028 int64 element_count = shape[dim_index];
1029 // We have reached the right-most dimension of the tensor.
1030 if (dim_index == shape_size - 1) {
1031 for (int64 i = 0; i < element_count; i++) {
1032 if (*data_index >= limit) {
1033 // If not enough elements have been printed, append "...".
1034 if (dim_index != 0) {
1035 strings::StrAppend(result, "...");
1036 }
1037 return;
1038 }
1039 if (i > 0) strings::StrAppend(result, " ");
1040 strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
1041 }
1042 return;
1043 }
1044 // Loop every element of one dim.
1045 for (int64 i = 0; i < element_count; i++) {
1046 bool flag = false;
1047 if (*data_index < limit) {
1048 strings::StrAppend(result, "[");
1049 flag = true;
1050 }
1051 // For each element, print the sub-dim.
1052 PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
1053 result);
1054 if (*data_index < limit || flag) {
1055 strings::StrAppend(result, "]");
1056 flag = false;
1057 }
1058 }
1059 }
1060
1061 // Appends the spacing between elements for a given dim onto a result string
1062 void PrintDimSpacing(int dim_index, int num_dims, string* result) {
1063 if (dim_index == num_dims - 1) {
1064 strings::StrAppend(result, " ");
1065 return;
1066 }
1067 for (int j = 0; j < num_dims - dim_index - 1; j++) {
1068 strings::StrAppend(result, "\n");
1069 }
1070 for (int j = 0; j <= dim_index; j++) {
1071 strings::StrAppend(result, " ");
1072 }
1073 }
1074
1075 // Print from left dim to right dim recursively.
1076 template <typename T>
1077 void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
1078 int64 num_elts_at_ends, int num_dims, const T* data,
1079 int64 data_index, string* result) {
1080 // We have recursed beyond all the dimensions into a single element
1081 // of the tensor.
1082 if (dim_index == num_dims) {
1083 strings::StrAppend(result, PrintOneElement(data[data_index], true));
1084 return;
1085 }
1086
1087 strings::StrAppend(result, "[");
1088 int64 element_count = shape[dim_index];
1089 int64 start_of_end =
1090 std::max(num_elts_at_ends, element_count - num_elts_at_ends);
1091
1092 // Loop every element of one dim.
1093 int64 elements_per_iter = 1;
1094 for (int i = dim_index + 1; i < num_dims; i++) {
1095 elements_per_iter *= shape[i];
1096 }
1097 for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
1098 if (i > 0) {
1099 PrintDimSpacing(dim_index, num_dims, result);
1100 }
1101
1102 // For each element, print the sub-dim.
1103 PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1104 data_index + elements_per_iter * i, result);
1105 }
1106 if (element_count > 2 * num_elts_at_ends) {
1107 PrintDimSpacing(dim_index, num_dims, result);
1108 strings::StrAppend(result, "...");
1109 }
1110 for (int64 i = start_of_end; i < element_count; i++) {
1111 // For each element, print the sub-dim.
1112 PrintDimSpacing(dim_index, num_dims, result);
1113 PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1114 data_index + elements_per_iter * i, result);
1115 }
1116
1117 strings::StrAppend(result, "]");
1118 }
1119
1120 template <typename T>
1121 string SummarizeArray(int64 limit, int64 num_elts,
1122 const TensorShape& tensor_shape, const char* data,
1123 const bool print_v2) {
1124 string ret;
1125 const T* array = reinterpret_cast<const T*>(data);
1126
1127 const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
1128 if (shape.empty()) {
1129 for (int64 i = 0; i < limit; ++i) {
1130 if (i > 0) strings::StrAppend(&ret, " ");
1131 strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
1132 }
1133 if (num_elts > limit) strings::StrAppend(&ret, "...");
1134 return ret;
1135 }
1136 if (print_v2) {
1137 const int num_dims = tensor_shape.dims();
1138 PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
1139 } else {
1140 int64 data_index = 0;
1141 const int shape_size = tensor_shape.dims();
1142 PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
1143
1144 if (num_elts > limit) strings::StrAppend(&ret, "...");
1145 }
1146
1147 return ret;
1148 }
1149 } // namespace
1150
1151 string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
1152 const int64 num_elts = NumElements();
1153 if (max_entries < 0) {
1154 max_entries = num_elts;
1155 }
1156 size_t limit = std::min(max_entries, num_elts);
1157 if ((limit > 0) && (buf_ == nullptr)) {
1158 return strings::StrCat("uninitialized Tensor of ", num_elts,
1159 " elements of type ", dtype());
1160 }
1161 const char* data = limit > 0 ? tensor_data().data() : nullptr;
1162 switch (dtype()) {
1163 case DT_BFLOAT16:
1164 return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
1165 break;
1166 case DT_HALF:
1167 return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
1168 print_v2);
1169 break;
1170 case DT_FLOAT:
1171 return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
1172 break;
1173 case DT_DOUBLE:
1174 return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
1175 break;
1176 case DT_UINT32:
1177 return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
1178 break;
1179 case DT_INT32:
1180 return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
1181 break;
1182 case DT_UINT8:
1183 case DT_QUINT8:
1184 return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
1185 break;
1186 case DT_UINT16:
1187 case DT_QUINT16:
1188 return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
1189 break;
1190 case DT_INT16:
1191 case DT_QINT16:
1192 return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
1193 break;
1194 case DT_INT8:
1195 case DT_QINT8:
1196 return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
1197 break;
1198 case DT_UINT64:
1199 return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
1200 break;
1201 case DT_INT64:
1202 return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
1203 break;
1204 case DT_BOOL:
1205 // TODO(tucker): Is it better to emit "True False..."? This
1206 // will emit "1 0..." which is more compact.
1207 return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
1208 break;
1209 case DT_STRING:
1210 return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
1211 break;
1212 default: {
1213 // All irregular cases
1214 string ret;
1215 if (print_v2) {
1216 strings::StrAppend(&ret, "[");
1217 }
1218 // TODO(irving): Don't call flat every time around this
1219 // loop.
1220 for (size_t i = 0; i < limit; ++i) {
1221 if (i > 0) strings::StrAppend(&ret, " ");
1222 switch (dtype()) {
1223 case DT_VARIANT: {
1224 const Variant& v = flat<Variant>()(i);
1225 strings::StrAppend(&ret, v.DebugString());
1226 } break;
1227 default:
1228 // TODO(zhifengc, josh11b): Pretty-print other types (bool,
1229 // complex64, quantized).
1230 strings::StrAppend(&ret, "?");
1231 }
1232 }
1233 if (max_entries < num_elts) strings::StrAppend(&ret, "...");
1234 if (print_v2) {
1235 strings::StrAppend(&ret, "]");
1236 }
1237 return ret;
1238 }
1239 }
1240 }
1241
1242 StringPiece Tensor::tensor_data() const {
1243 if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors
1244 return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
1245 }
1246
1247 bool Tensor::SharesBufferWith(const Tensor& b) const {
1248 return buf_ != nullptr && b.buf_ != nullptr &&
1249 buf_->root_buffer() == b.buf_->root_buffer();
1250 }
1251
1252 string Tensor::DebugString(int num_values) const {
1253 return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1254 " shape: ", shape().DebugString(),
1255 " values: ", SummarizeValue(num_values), ">");
1256 }
1257
1258 string Tensor::DeviceSafeDebugString() const {
1259 return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1260 " shape: ", shape().DebugString(), ">");
1261 }
1262
1263 void Tensor::FillDescription(TensorDescription* description) const {
1264 description->set_dtype(dtype());
1265 shape().AsProto(description->mutable_shape());
1266 if (buf_ != nullptr && buf_->data() != nullptr) {
1267 buf_->FillAllocationDescription(
1268 description->mutable_allocation_description());
1269 }
1270 }
1271
1272 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
1273 gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1274 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1275 int64 offset = orig.size() - num_out_dims;
1276 for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
1277 const int64 in_dim = out_dim + offset;
1278 out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
1279 }
1280 for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
1281 out_dims[0] *= orig[in_dim];
1282 }
1283 return out_dims;
1284 }
1285
1286 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
1287 gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1288 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1289 for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
1290 out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
1291 }
1292 for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
1293 out_dims[num_out_dims - 1] *= orig[in_dim];
1294 }
1295 return out_dims;
1296 }
1297
1298 } // namespace tensorflow
1299