• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implementation notes:
17 //
18 // Tensor.cc uses a few templated classes and structs to facilitate
19 // implementation of the Tensor class.
20 //
21 // * Buffer<T>: provides the implementation for a typed array T[n].
22 //   The array is allocated by the given allocator. It runs T's
23 //   default constructors and destructors when T is not a simple type
24 //   (e.g., string.), and skips them otherwise.
25 //
26 // * Helper<T>: provides various routines given type T.  The routines
27 //   include running the constructor and destructor of T[], encoding
28 //   and decoding T[] into/from a Cord, etc.
29 
30 #include "tensorflow/core/framework/tensor.h"
31 
32 #include "absl/strings/escaping.h"
33 #include "tensorflow/core/framework/allocation_description.pb.h"
34 #include "tensorflow/core/framework/log_memory.h"
35 #include "tensorflow/core/framework/resource_handle.pb.h"
36 #include "tensorflow/core/framework/tensor.pb.h"
37 #include "tensorflow/core/framework/tensor_description.pb.h"
38 #include "tensorflow/core/framework/type_traits.h"
39 #include "tensorflow/core/framework/typed_allocator.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/framework/variant.h"
42 #include "tensorflow/core/framework/variant_encode_decode.h"
43 #include "tensorflow/core/framework/variant_op_registry.h"
44 #include "tensorflow/core/framework/variant_tensor_data.h"
45 #include "tensorflow/core/lib/core/coding.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/gtl/inlined_vector.h"
49 #include "tensorflow/core/lib/strings/str_util.h"
50 #include "tensorflow/core/lib/strings/strcat.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/tensor_coding.h"
55 #include "tensorflow/core/platform/types.h"
56 
57 namespace tensorflow {
58 
59 // Allow Tensors to be stored inside Variants with automatic
60 // encoding/decoding when those Variants are themselves being decoded
61 // in a Tensor's FromProto.
62 //
63 // NOTE(mrry): The corresponding "copy function" registrations can be found in
64 // ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
65 // code).
66 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
67 
GetAllocatedBytes(size_t * out_bytes) const68 bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
69   AllocationDescription allocation_description;
70   FillAllocationDescription(&allocation_description);
71   if (allocation_description.allocated_bytes() > 0) {
72     *out_bytes = allocation_description.allocated_bytes();
73     return true;
74   } else {
75     return false;
76   }
77 }
78 
79 namespace {
80 
81 // An un-templated base class for Buffer.
82 class BufferBase : public TensorBuffer {
83  public:
BufferBase(Allocator * alloc,void * data_ptr)84   explicit BufferBase(Allocator* alloc, void* data_ptr)
85       : TensorBuffer(data_ptr), alloc_(alloc) {}
86 
root_buffer()87   TensorBuffer* root_buffer() override { return this; }
88 
GetAllocatedBytes(size_t * out_bytes) const89   bool GetAllocatedBytes(size_t* out_bytes) const override {
90     if (alloc_->TracksAllocationSizes()) {
91       *out_bytes = alloc_->AllocatedSize(data());
92       return *out_bytes > 0;
93     } else {
94       return false;
95     }
96   }
97 
FillAllocationDescription(AllocationDescription * proto) const98   void FillAllocationDescription(AllocationDescription* proto) const override {
99     void* data_ptr = data();
100     int64 rb = size();
101     proto->set_requested_bytes(rb);
102     proto->set_allocator_name(alloc_->Name());
103     proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
104     if (alloc_->TracksAllocationSizes()) {
105       int64 ab = alloc_->AllocatedSize(data_ptr);
106       proto->set_allocated_bytes(ab);
107       int64 id = alloc_->AllocationId(data_ptr);
108       if (id > 0) {
109         proto->set_allocation_id(id);
110       }
111       if (RefCountIsOne()) {
112         proto->set_has_single_reference(true);
113       }
114     }
115   }
116 
117  protected:
RecordDeallocation()118   void RecordDeallocation() {
119     LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
120                                         alloc_->Name());
121   }
122 
123   Allocator* const alloc_;
124 };
125 
126 // Typed ref-counted buffer: T[n].
127 template <typename T>
128 class Buffer : public BufferBase {
129  public:
130   Buffer(Allocator* a, int64 n);
131   Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);
132 
size() const133   size_t size() const override { return sizeof(T) * elem_; }
134 
135  private:
136   int64 elem_;
137 
138   ~Buffer() override;
139 
140   TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
141 };
142 
LogUnexpectedSize(int64 actual,int64 expected)143 void LogUnexpectedSize(int64 actual, int64 expected) {
144   LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
145 }
146 
MemoryLoggingEnabled()147 bool MemoryLoggingEnabled() {
148   static bool memory_logging_enabled = LogMemory::IsEnabled();
149   return memory_logging_enabled;
150 }
151 
152 // A set of helper functions depending on T.
153 template <typename T>
154 struct Helper {
155   // By default, we assume T is a simple type (float, int32, etc.)
156   static_assert(is_simple_type<T>::value, "T is not a simple type.");
157   typedef protobuf::RepeatedField<T> RepeatedFieldType;
158 
159   // Encoder of simple type T to a string.  We do a copy.
160   template <typename Destination>
Encodetensorflow::__anon8c201cc00111::Helper161   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
162     DCHECK_EQ(in->size(), sizeof(T) * n);
163     port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
164                            out);
165   }
166 
167   // Decoder of simple type T. Copy the bytes from "in" into the
168   // tensor buffer.
169   template <typename Source>
Decodetensorflow::__anon8c201cc00111::Helper170   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
171     if (in.size() != sizeof(T) * n) {
172       LogUnexpectedSize(in.size(), sizeof(T) * n);
173       return nullptr;
174     }
175     Buffer<T>* buf = new Buffer<T>(a, n);
176     char* data = buf->template base<char>();
177     if (data == nullptr) {
178       buf->Unref();
179       return nullptr;
180     }
181     port::CopyToArray(in, data);
182     return buf;
183   }
184 
185   // Memory usage.
TotalBytestensorflow::__anon8c201cc00111::Helper186   static int64 TotalBytes(TensorBuffer* in, int64 n) {
187     DCHECK_EQ(in->size(), sizeof(T) * n);
188     return in->size();
189   }
190 };
191 
192 // Helper specialization for string (the only non-simple type we
193 // support).
194 template <>
195 struct Helper<tstring> {
196   // Proto message uses RepeatedFieldType to hold repeated T.
197   typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
198 
199   // Encodes "n" elements of type string stored in "in" into Cord
200   // "out", which is usually the TensorProto::tensor_content.
201   template <typename Destination>
Encodetensorflow::__anon8c201cc00111::Helper202   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
203     port::EncodeStringList(in->base<const tstring>(), n, out);
204   }
205 
206   // Decodes "n" elements of type string from "in" and constructs a
207   // buffer out of it. Returns nullptr if the decoding fails. "in" is
208   // usually the TensorProto::tensor_content.
209   template <typename Source>
Decodetensorflow::__anon8c201cc00111::Helper210   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
211     Buffer<tstring>* buf = new Buffer<tstring>(a, n);
212     tstring* strings = buf->template base<tstring>();
213     if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
214       buf->Unref();
215       return nullptr;
216     }
217     return buf;
218   }
219 
220   // Returns the estimated memory usage of "n" elements of type T
221   // stored in buffer "in".
TotalBytestensorflow::__anon8c201cc00111::Helper222   static int64 TotalBytes(TensorBuffer* in, int n) {
223     int64 tot = in->size();
224     DCHECK_EQ(tot, sizeof(tstring) * n);
225     const tstring* p = in->base<const tstring>();
226     for (int i = 0; i < n; ++i, ++p) tot += p->size();
227     return tot;
228   }
229 };
230 
231 template <>
232 struct Helper<ResourceHandle> {
233   // Proto message uses RepeatedFieldType to hold repeated T.
234   typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
235 
236   // Encodes "n" elements of type ResourceHandle stored in "in" into destination
237   // "out", which is usually the TensorProto::tensor_content.
238   template <typename Destination>
Encodetensorflow::__anon8c201cc00111::Helper239   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
240     EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
241                              port::NewStringListEncoder(out));
242   }
243 
244   // Decodes "n" elements of type string from "in" and constructs a
245   // buffer out of it. Returns nullptr if the decoding fails. "in" is
246   // usually the TensorProto::tensor_content.
247   template <typename Source>
Decodetensorflow::__anon8c201cc00111::Helper248   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
249     auto* buf = new Buffer<ResourceHandle>(a, n);
250     ResourceHandle* ps = buf->template base<ResourceHandle>();
251     if (ps == nullptr ||
252         !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
253       buf->Unref();
254       return nullptr;
255     }
256     return buf;
257   }
258 
259   // Returns the estimated memory usage of "n" elements of type T
260   // stored in buffer "in".
TotalBytestensorflow::__anon8c201cc00111::Helper261   static int64 TotalBytes(TensorBuffer* in, int n) {
262     return n * sizeof(ResourceHandle);
263   }
264 };
265 
266 template <>
267 struct Helper<Variant> {
268   // Encodes "n" elements of type Variant stored in "in" into destination
269   // "out", which is usually the TensorProto::tensor_content.
270   template <typename Destination>
Encodetensorflow::__anon8c201cc00111::Helper271   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
272     EncodeVariantList(in->base<const Variant>(), n,
273                       port::NewStringListEncoder(out));
274   }
275 
276   // Decodes "n" elements of type Variant from "in" and constructs a
277   // buffer out of it. Returns nullptr if the decoding fails. "in" is
278   // usually the TensorProto::tensor_content.
279   template <typename Source>
Decodetensorflow::__anon8c201cc00111::Helper280   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
281     auto* buf = new Buffer<Variant>(a, n);
282     Variant* ps = buf->template base<Variant>();
283     if (ps == nullptr ||
284         !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
285       buf->Unref();
286       return nullptr;
287     }
288     return buf;
289   }
290 
291   // Returns the estimated memory usage of "n" elements of type T
292   // stored in buffer "in".
TotalBytestensorflow::__anon8c201cc00111::Helper293   static int64 TotalBytes(TensorBuffer* in, int n) {
294     return n * sizeof(Variant);
295   }
296 };
297 
298 template <typename T>
299 struct ProtoHelper {};
300 
301 // For a C++ type "T" (float, double, int32, etc.), the repeated field
302 // "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
303 // int32, string, etc) in the TensorProto is used for serializing the
304 // tensor of type "T".
305 #define PROTO_TRAITS(T, F, N)                                          \
306   template <>                                                          \
307   struct ProtoHelper<T> {                                              \
308     typedef Helper<F>::RepeatedFieldType FieldType;                    \
309     static FieldType::const_iterator Begin(const TensorProto& proto) { \
310       return proto.N##_val().begin();                                  \
311     }                                                                  \
312     static size_t NumElements(const TensorProto& proto) {              \
313       return proto.N##_val().size();                                   \
314     }                                                                  \
315     static void Fill(const T* data, size_t n, TensorProto* proto) {    \
316       typename ProtoHelper<T>::FieldType copy(data, data + n);         \
317       proto->mutable_##N##_val()->Swap(&copy);                         \
318     }                                                                  \
319   };
320 PROTO_TRAITS(float, float, float);
321 PROTO_TRAITS(double, double, double);
322 PROTO_TRAITS(int32, int32, int);
323 PROTO_TRAITS(uint8, int32, int);
324 PROTO_TRAITS(uint16, int32, int);
325 PROTO_TRAITS(uint32, uint32, uint32);
326 PROTO_TRAITS(int16, int32, int);
327 PROTO_TRAITS(int8, int32, int);
328 PROTO_TRAITS(bool, bool, bool);
329 PROTO_TRAITS(tstring, tstring, string);
330 PROTO_TRAITS(qint8, int32, int);
331 PROTO_TRAITS(quint8, int32, int);
332 PROTO_TRAITS(qint16, int32, int);
333 PROTO_TRAITS(quint16, int32, int);
334 #undef PROTO_TRAITS
335 
336 template <>
337 struct ProtoHelper<int64> {
Begintensorflow::__anon8c201cc00111::ProtoHelper338   static const int64* Begin(const TensorProto& proto) {
339     return reinterpret_cast<const int64*>(proto.int64_val().begin());
340   }
NumElementstensorflow::__anon8c201cc00111::ProtoHelper341   static size_t NumElements(const TensorProto& proto) {
342     return proto.int64_val().size();
343   }
Filltensorflow::__anon8c201cc00111::ProtoHelper344   static void Fill(const int64* data, size_t n, TensorProto* proto) {
345     protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
346     proto->mutable_int64_val()->Swap(&copy);
347   }
348 };
349 
350 template <>
351 struct ProtoHelper<uint64> {
Begintensorflow::__anon8c201cc00111::ProtoHelper352   static const uint64* Begin(const TensorProto& proto) {
353     return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
354   }
NumElementstensorflow::__anon8c201cc00111::ProtoHelper355   static size_t NumElements(const TensorProto& proto) {
356     return proto.uint64_val().size();
357   }
Filltensorflow::__anon8c201cc00111::ProtoHelper358   static void Fill(const uint64* data, size_t n, TensorProto* proto) {
359     protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
360     proto->mutable_uint64_val()->Swap(&copy);
361   }
362 };
363 
364 template <>
365 struct ProtoHelper<ResourceHandle> {
Begintensorflow::__anon8c201cc00111::ProtoHelper366   static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
367       const TensorProto& proto) {
368     return proto.resource_handle_val().begin();
369   }
NumElementstensorflow::__anon8c201cc00111::ProtoHelper370   static size_t NumElements(const TensorProto& proto) {
371     return proto.resource_handle_val().size();
372   }
Filltensorflow::__anon8c201cc00111::ProtoHelper373   static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
374     auto* handles = proto->mutable_resource_handle_val();
375     handles->Clear();
376     for (size_t i = 0; i < n; i++) {
377       data[i].AsProto(handles->Add());
378     }
379   }
380 };
381 
382 template <>
383 struct ProtoHelper<Variant> {
384   static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
Begintensorflow::__anon8c201cc00111::ProtoHelper385   Begin(const TensorProto& proto) {
386     return proto.variant_val().begin();
387   }
NumElementstensorflow::__anon8c201cc00111::ProtoHelper388   static size_t NumElements(const TensorProto& proto) {
389     return proto.variant_val().size();
390   }
Filltensorflow::__anon8c201cc00111::ProtoHelper391   static void Fill(const Variant* data, size_t n, TensorProto* proto) {
392     auto* variant_values = proto->mutable_variant_val();
393     variant_values->Clear();
394     for (size_t i = 0; i < n; ++i) {
395       VariantTensorData tmp;
396       data[i].Encode(&tmp);
397       tmp.ToProto(variant_values->Add());
398     }
399   }
400 };
401 
402 template <>
403 struct ProtoHelper<complex64> {
404   typedef Helper<float>::RepeatedFieldType FieldType;
Begintensorflow::__anon8c201cc00111::ProtoHelper405   static const complex64* Begin(const TensorProto& proto) {
406     return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
407   }
NumElementstensorflow::__anon8c201cc00111::ProtoHelper408   static size_t NumElements(const TensorProto& proto) {
409     return proto.scomplex_val().size() / 2;
410   }
Filltensorflow::__anon8c201cc00111::ProtoHelper411   static void Fill(const complex64* data, size_t n, TensorProto* proto) {
412     const float* p = reinterpret_cast<const float*>(data);
413     FieldType copy(p, p + n * 2);
414     proto->mutable_scomplex_val()->Swap(&copy);
415   }
416 };
417 
418 template <>
419 struct ProtoHelper<complex128> {
420   typedef Helper<double>::RepeatedFieldType FieldType;
Begintensorflow::__anon8c201cc00111::ProtoHelper421   static const complex128* Begin(const TensorProto& proto) {
422     return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
423   }
NumElementstensorflow::__anon8c201cc00111::ProtoHelper424   static size_t NumElements(const TensorProto& proto) {
425     return proto.dcomplex_val().size() / 2;
426   }
Filltensorflow::__anon8c201cc00111::ProtoHelper427   static void Fill(const complex128* data, size_t n, TensorProto* proto) {
428     const double* p = reinterpret_cast<const double*>(data);
429     FieldType copy(p, p + n * 2);
430     proto->mutable_dcomplex_val()->Swap(&copy);
431   }
432 };
433 
434 template <>
435 struct ProtoHelper<qint32> {
436   typedef Helper<int32>::RepeatedFieldType FieldType;
Begintensorflow::__anon8c201cc00111::ProtoHelper437   static const qint32* Begin(const TensorProto& proto) {
438     return reinterpret_cast<const qint32*>(proto.int_val().data());
439   }
NumElementstensorflow::__anon8c201cc00111::ProtoHelper440   static size_t NumElements(const TensorProto& proto) {
441     return proto.int_val().size();
442   }
Filltensorflow::__anon8c201cc00111::ProtoHelper443   static void Fill(const qint32* data, size_t n, TensorProto* proto) {
444     const int32* p = reinterpret_cast<const int32*>(data);
445     FieldType copy(p, p + n);
446     proto->mutable_int_val()->Swap(&copy);
447   }
448 };
449 
450 template <>
451 struct ProtoHelper<bfloat16> {
Filltensorflow::__anon8c201cc00111::ProtoHelper452   static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
453     proto->mutable_half_val()->Reserve(n);
454     for (size_t i = 0; i < n; ++i) {
455       proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
456     }
457   }
458 };
459 
460 template <>
461 struct ProtoHelper<Eigen::half> {
Filltensorflow::__anon8c201cc00111::ProtoHelper462   static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
463     proto->mutable_half_val()->Reserve(n);
464     for (size_t i = 0; i < n; ++i) {
465       proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
466     }
467   }
468 };
469 
470 template <typename T>
Buffer(Allocator * a,int64 n)471 Buffer<T>::Buffer(Allocator* a, int64 n)
472     : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
473       elem_(n) {}
474 
475 template <typename T>
Buffer(Allocator * a,int64 n,const AllocationAttributes & allocation_attr)476 Buffer<T>::Buffer(Allocator* a, int64 n,
477                   const AllocationAttributes& allocation_attr)
478     : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
479       elem_(n) {}
480 
481 template <typename T>
~Buffer()482 Buffer<T>::~Buffer() {
483   if (data()) {
484     if (MemoryLoggingEnabled()) {
485       RecordDeallocation();
486     }
487     TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
488   }
489 }
490 
491 // Allocates a T[n] buffer. Fills in the buffer with repeated values
492 // in "in".  If "in" has less values than "n", fills the rest of T[n]
493 // with the last value. If "in" has no values, fills T[n] with the
494 // default value for T.
495 //
496 // This routine is using the typed fields (float_val, etc.) in the
497 // tensor proto as opposed to the untyped binary representation
498 // (tensor_content). This is used when we expect the TensorProto is
499 // used by a client program which may not know how to encode a tensor
500 // in the compact binary representation.
501 template <typename T>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)502 TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
503   CHECK_GT(n, 0);
504   Buffer<T>* buf = new Buffer<T>(a, n);
505   T* data = buf->template base<T>();
506   if (data == nullptr) {
507     buf->Unref();
508     return nullptr;
509   }
510 
511   const int64 in_n = ProtoHelper<T>::NumElements(in);
512   if (in_n <= 0) {
513     std::fill_n(data, n, T());
514   } else {
515     auto begin = ProtoHelper<T>::Begin(in);
516     if (n <= in_n) {
517       std::copy_n(begin, n, data);
518     } else {
519       std::copy_n(begin, in_n, data);
520       if (std::is_trivially_copyable<T>::value) {
521         const T last = *(data + in_n - 1);
522         std::fill_n(data + in_n, n - in_n, last);
523       } else {
524         const T& last = *(data + in_n - 1);
525         std::fill_n(data + in_n, n - in_n, last);
526       }
527     }
528   }
529 
530   return buf;
531 }
532 
533 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)534 TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
535                                       int64 n) {
536   CHECK_GT(n, 0);
537   Buffer<Variant>* buf = new Buffer<Variant>(a, n);
538   Variant* data = buf->template base<Variant>();
539   if (data == nullptr) {
540     buf->Unref();
541     return nullptr;
542   }
543   const int64 in_n = ProtoHelper<Variant>::NumElements(in);
544   if (in_n <= 0) {
545     std::fill_n(data, n, Variant());
546   } else {
547     // If tensor shape says we have n < in_n elements in the output tensor
548     // then make sure to only decode the first n out of the in_n elements in the
549     // in tensors. In all other cases, we decode all in_n elements of in and set
550     // the remaining elements up to n to be the default Variant() value.
551     const int64 real_n = n < in_n ? n : in_n;
552     for (int64 i = 0; i < real_n; ++i) {
553       data[i] = in.variant_val(i);
554       if (!DecodeUnaryVariant(&data[i])) {
555         LOG(ERROR) << "Could not decode variant with type_name: \""
556                    << data[i].TypeName()
557                    << "\".  Perhaps you forgot to register a "
558                       "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
559         buf->Unref();
560         return nullptr;
561       }
562     }
563     for (int64 i = in_n; i < n; ++i) {
564       data[i] = Variant();
565     }
566   }
567   return buf;
568 }
569 
570 // fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
571 // identical to uint16 but with data stored in half_val instead of int_val (ie.,
572 // we don't use ProtoHelper<uint16>).
573 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)574 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
575                                           int64 n) {
576   CHECK_GT(n, 0);
577   Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
578   uint16* data = buf->template base<uint16>();
579   if (data == nullptr) {
580     buf->Unref();
581     return nullptr;
582   }
583   const int64 in_n = in.half_val().size();
584   auto begin = in.half_val().begin();
585   if (n <= in_n) {
586     std::copy_n(begin, n, data);
587   } else if (in_n > 0) {
588     std::copy_n(begin, in_n, data);
589     const uint16 last = *(data + in_n - 1);
590     std::fill_n(data + in_n, n - in_n, last);
591   } else {
592     std::fill_n(data, n, 0);
593   }
594   return buf;
595 }
596 
597 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)598 TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
599                                        int64 n) {
600   CHECK_GT(n, 0);
601   Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
602   uint16* data = buf->template base<uint16>();
603   if (data == nullptr) {
604     buf->Unref();
605     return nullptr;
606   }
607   const int64 in_n = in.half_val().size();
608   auto begin = in.half_val().begin();
609   if (n <= in_n) {
610     std::copy_n(begin, n, data);
611   } else if (in_n > 0) {
612     std::copy_n(begin, in_n, data);
613     const uint16 last = *(data + in_n - 1);
614     std::fill_n(data + in_n, n - in_n, last);
615   } else {
616     std::fill_n(data, n, 0);
617   }
618   return buf;
619 }
620 
621 // Copies T[n] stored in the buffer "in" into the repeated field in
622 // "out" corresponding to type T.
623 template <typename T>
ToProtoField(const TensorBuffer & in,int64 n,TensorProto * out)624 void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
625   const T* data = in.base<const T>();
626   // NOTE: T may not the same as
627   // ProtoHelper<T>::FieldType::value_type.  E.g., T==int16,
628   // ProtoHelper<T>::FieldType::value_type==int32.  If performance is
629   // critical, we can specialize T=float and do memcpy directly.
630   ProtoHelper<T>::Fill(data, n, out);
631 }
632 
RefIfNonNull(core::RefCounted * buf)633 void RefIfNonNull(core::RefCounted* buf) {
634   if (buf) buf->Ref();
635 }
636 
UnrefIfNonNull(core::RefCounted * buf)637 void UnrefIfNonNull(core::RefCounted* buf) {
638   if (buf) buf->Unref();
639 }
640 
641 }  // end namespace
642 
Tensor()643 Tensor::Tensor() : Tensor(DT_FLOAT) {}
644 
Tensor(DataType type)645 Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}
646 
Tensor(DataType type,const TensorShape & shape,TensorBuffer * buf)647 Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
648     : shape_(shape), buf_(buf) {
649   set_dtype(type);
650   RefIfNonNull(buf);
651 }
652 
Tensor(DataType type,const TensorShape & shape,core::RefCountPtr<TensorBuffer> buf)653 Tensor::Tensor(DataType type, const TensorShape& shape,
654                core::RefCountPtr<TensorBuffer> buf)
655     : shape_(shape), buf_(buf.release()) {
656   set_dtype(type);
657 }
658 
IsInitialized() const659 bool Tensor::IsInitialized() const {
660   return (buf_ != nullptr && buf_->data() != nullptr) ||
661          shape_.num_elements() == 0;
662 }
663 
CheckType(DataType expected_dtype) const664 void Tensor::CheckType(DataType expected_dtype) const {
665   CHECK_EQ(dtype(), expected_dtype)
666       << " " << DataTypeString(expected_dtype) << " expected, got "
667       << DataTypeString(dtype());
668 }
669 
CheckTypeAndIsAligned(DataType expected_dtype) const670 void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
671   CHECK_EQ(dtype(), expected_dtype)
672       << " " << DataTypeString(expected_dtype) << " expected, got "
673       << DataTypeString(dtype());
674   CHECK(IsAligned()) << "ptr = " << base<void>();
675 }
676 
CheckIsAlignedAndSingleElement() const677 void Tensor::CheckIsAlignedAndSingleElement() const {
678   CHECK(IsAligned()) << "Aligned and single element";
679   CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
680 }
681 
~Tensor()682 Tensor::~Tensor() { UnrefIfNonNull(buf_); }
683 
BitcastFrom(const Tensor & other,DataType dtype,const TensorShape & shape)684 Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
685                            const TensorShape& shape) {
686   int in_size = DataTypeSize(other.dtype());
687   int out_size = DataTypeSize(dtype);
688   if (in_size == 0) {
689     return errors::InvalidArgument("other tensor has zero-sized data type");
690   }
691   if (out_size == 0) {
692     return errors::InvalidArgument("specified output type is zero-sized");
693   }
694   if (shape.num_elements() * out_size !=
695       other.shape().num_elements() * in_size) {
696     return errors::InvalidArgument(
697         "input and output shapes/data type sizes are not compatible");
698   }
699   shape_ = shape;
700   shape_.set_data_type(dtype);
701   if (buf_ != other.buf_) {
702     UnrefIfNonNull(buf_);
703     buf_ = other.buf_;
704     RefIfNonNull(buf_);
705   }
706   return Status::OK();
707 }
708 
709 // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
710 // For the latter case, we have to make sure that the refcount is
711 // one both for the SubBuffer _and_ the underlying TensorBuffer.
RefCountIsOne() const712 bool Tensor::RefCountIsOne() const {
713   return buf_ != nullptr && buf_->RefCountIsOne() &&
714          buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
715 }
716 
// CASES_WITH_DEFAULT() expands to a switch statement conditioned on
// TYPE_ENUM. Each supported type gets a case that runs STMTS with T aliased
// to that type; DT_INVALID runs INVALID and any other value runs DEFAULT.
// SINGLE_ARG() lets an argument containing commas pass through a nested macro
// call as one argument.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    using T = TYPE;                   \
    STMTS;                            \
    break;                            \
  }
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(tstring, SINGLE_ARG(STMTS))                           \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }
758 
759 #define CASES(TYPE_ENUM, STMTS)                                      \
760   CASES_WITH_DEFAULT(TYPE_ENUM, STMTS,                               \
761                      LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
762                      , LOG(FATAL) << "Type not set";)
763 
Tensor(Allocator * a,DataType type,const TensorShape & shape)764 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
765     : shape_(shape), buf_(nullptr) {
766   set_dtype(type);
767   CHECK_NOTNULL(a);
768   if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
769     CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
770   }
771   if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
772     LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
773                                       *this);
774   }
775 }
776 
Tensor(Allocator * a,DataType type,const TensorShape & shape,const AllocationAttributes & allocation_attr)777 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
778                const AllocationAttributes& allocation_attr)
779     : shape_(shape), buf_(nullptr) {
780   set_dtype(type);
781   CHECK_NOTNULL(a);
782   if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
783     CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
784   }
785   if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
786       buf_ != nullptr && buf_->data() != nullptr) {
787     LogMemory::RecordTensorAllocation("Unknown (with attributes)",
788                                       LogMemory::UNKNOWN_STEP_ID, *this);
789   }
790 }
791 
792 // NOTE(mrry): The default allocator for a Tensor (when none is specified) is
793 // the default CPU allocator for NUMA zone 0. Accessing that currently involves
794 // acquiring a lock, which guards initialization of the per-NUMA zone
795 // allocators, and becomes highly contended.
796 //
797 // Note also that it would be better if all Tensor allocations required the user
798 // to specify an allocator, for purposes of accounting, etc. However, the
799 // default allocator is widely used throughout the codebase and in client code.
get_default_cpu_allocator()800 static Allocator* get_default_cpu_allocator() {
801   static Allocator* default_cpu_allocator =
802       cpu_allocator(port::kNUMANoAffinity);
803   return default_cpu_allocator;
804 }
805 
Tensor(DataType type,const TensorShape & shape)806 Tensor::Tensor(DataType type, const TensorShape& shape)
807     : Tensor(get_default_cpu_allocator(), type, shape) {}
808 
GetAllocatedBytes(size_t * out_bytes) const809 bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
810     size_t* out_bytes) const {
811   // `this->FillAllocationDescription()` never sets allocated bytes information,
812   // so we can short-circuit the construction of an `AllocationDescription`.
813   return false;
814 }
815 
FillAllocationDescription(AllocationDescription * proto) const816 void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
817     AllocationDescription* proto) const {
818   proto->set_requested_bytes(size());
819   proto->set_allocator_name("HostScalarTensorBuffer");
820   proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
821 }
822 
823 template <typename T>
824 class SubBuffer : public TensorBuffer {
825  public:
826   // This buffer is an alias to buf[delta, delta + n).
SubBuffer(TensorBuffer * buf,int64 delta,int64 n)827   SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
828       : TensorBuffer(buf->base<T>() + delta),
829         root_(buf->root_buffer()),
830         elem_(n) {
831     // Sanity check. The caller should ensure the sub buffer is valid.
832     CHECK_LE(root_->base<T>(), this->base<T>());
833     T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
834     CHECK_LE(this->base<T>(), root_limit);
835     CHECK_LE(this->base<T>() + n, root_limit);
836     // Hold a ref of the underlying root buffer.
837     // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
838     root_->Ref();
839   }
840 
size() const841   size_t size() const override { return sizeof(T) * elem_; }
root_buffer()842   TensorBuffer* root_buffer() override { return root_; }
GetAllocatedBytes(size_t * out_bytes) const843   bool GetAllocatedBytes(size_t* out_bytes) const override {
844     return root_->GetAllocatedBytes(out_bytes);
845   }
FillAllocationDescription(AllocationDescription * proto) const846   void FillAllocationDescription(AllocationDescription* proto) const override {
847     root_->FillAllocationDescription(proto);
848   }
849 
850  private:
851   TensorBuffer* root_;
852   int64 elem_;
853 
~SubBuffer()854   ~SubBuffer() override { root_->Unref(); }
855 
856   TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
857 };
858 
Slice(int64 start,int64 limit) const859 Tensor Tensor::Slice(int64 start, int64 limit) const {
860   CHECK_GE(dims(), 1);
861   CHECK_LE(0, start);
862   CHECK_LE(start, limit);
863   int64 dim0_size = shape_.dim_size(0);
864   CHECK_LE(limit, dim0_size);
865   if ((start == 0) && (limit == dim0_size)) {
866     return *this;
867   }
868   Tensor ret;
869   ret.shape_ = shape_;
870   ret.set_dtype(dtype());
871   ret.buf_ = nullptr;
872   if (dim0_size > 0) {
873     const int64 elems_per_dim0 = NumElements() / dim0_size;
874     const int64 delta = start * elems_per_dim0;
875     dim0_size = limit - start;
876     ret.shape_.set_dim(0, dim0_size);
877     const int64 num_elems = dim0_size * elems_per_dim0;
878     if (buf_) {
879       DataType dt = dtype();
880       CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
881     }
882   }
883   return ret;
884 }
885 
SubSlice(int64 index) const886 Tensor Tensor::SubSlice(int64 index) const {
887   CHECK_GE(dims(), 1);  // Crash ok.
888   CHECK_LE(0, index);   // Crash ok.
889   int64 dim0_size = shape_.dim_size(0);
890   CHECK_LE(index, dim0_size);  // Crash ok.
891   Tensor ret;
892   ret.shape_ = shape_;
893   ret.shape_.RemoveDim(0);
894   ret.set_dtype(dtype());
895   ret.buf_ = nullptr;
896   if (dim0_size > 0) {
897     const int64 elems_per_dim0 = NumElements() / dim0_size;
898     const int64 delta = index * elems_per_dim0;
899     const int64 num_elems = elems_per_dim0;
900     if (buf_) {
901       DataType dt = dtype();
902       CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
903     }
904   }
905   return ret;
906 }
907 
FromProto(const TensorProto & proto)908 bool Tensor::FromProto(const TensorProto& proto) {
909   return FromProto(get_default_cpu_allocator(), proto);
910 }
911 
FromProto(Allocator * a,const TensorProto & proto)912 bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
913   CHECK_NOTNULL(a);
914   TensorBuffer* p = nullptr;
915   if (!TensorShape::IsValid(proto.tensor_shape())) return false;
916   if (proto.dtype() == DT_INVALID) return false;
917   TensorShape shape(proto.tensor_shape());
918   const int64 N = shape.num_elements();
919   if (N > 0 && proto.dtype()) {
920     bool dtype_error = false;
921     if (!proto.tensor_content().empty()) {
922       const auto& content = proto.tensor_content();
923       CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
924                          dtype_error = true, dtype_error = true);
925     } else {
926       CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
927                          dtype_error = true, dtype_error = true);
928     }
929     if (dtype_error || p == nullptr) return false;
930   }
931   shape_ = shape;
932   set_dtype(proto.dtype());
933   UnrefIfNonNull(buf_);
934   buf_ = p;
935   // TODO(misard) add tracking of which kernels and steps are calling
936   // FromProto.
937   if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
938     LogMemory::RecordTensorAllocation("Unknown (from Proto)",
939                                       LogMemory::UNKNOWN_STEP_ID, *this);
940   }
941   return true;
942 }
943 
AsProtoField(TensorProto * proto) const944 void Tensor::AsProtoField(TensorProto* proto) const {
945   proto->Clear();
946   shape_.AsProto(proto->mutable_tensor_shape());
947   proto->set_dtype(dtype());
948   if (buf_) {
949     CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
950   }
951 }
952 
AsProtoTensorContent(TensorProto * proto) const953 void Tensor::AsProtoTensorContent(TensorProto* proto) const {
954   proto->Clear();
955   proto->set_dtype(dtype());
956   shape_.AsProto(proto->mutable_tensor_shape());
957   if (buf_) {
958     CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
959                                      proto->mutable_tensor_content()));
960   }
961 }
962 
TotalBytes() const963 size_t Tensor::TotalBytes() const {
964   if (shape_.num_elements() == 0) return 0;
965   CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
966   CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
967   return 0;  // Makes compiler happy.
968 }
969 
AllocatedBytes() const970 size_t Tensor::AllocatedBytes() const {
971   if (buf_) {
972     size_t ret;
973     if (buf_->GetAllocatedBytes(&ret)) {
974       return ret;
975     }
976   }
977   return TotalBytes();
978 }
979 
CanUseDMA() const980 bool Tensor::CanUseDMA() const {
981   CASES(dtype(), return is_simple_type<T>::value);
982   return false;  // Makes compiler happy.
983 }
984 
// The type-dispatch macros are not used below this point; undefine all four
// so that none of them leaks into the rest of the translation unit.
#undef CASES
#undef CASES_WITH_DEFAULT
#undef CASE
#undef SINGLE_ARG
987 
988 namespace {
989 
990 // StrCat and StrAppend don't support Eigen::half directly at the moment, and
991 // we would like to keep them compatible with their absl counterparts, for ease
992 // of migration. We could rely on errors::internal::PrepareForStrCat() but the
993 // logic is so simple we can just replicate it here, where it is close to its
994 // usage and easy to change later. And there's the extra benefit of not
995 // accessing an 'internal' namespace.
PrintOneElement(const strings::AlphaNum & a,bool print_v2)996 inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
997                                                 bool print_v2) {
998   return a;
999 }
PrintOneElement(const tstring & a,bool print_v2)1000 inline string PrintOneElement(const tstring& a, bool print_v2) {
1001   if (print_v2) {
1002     return "\"" + absl::Utf8SafeCEscape(a) + "\"";
1003   } else {
1004     return absl::Utf8SafeCEscape(a);
1005   }
1006 }
PrintOneElement(const Eigen::half & h,bool print_v2)1007 inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
1008   return static_cast<float>(h);
1009 }
1010 
PrintOneElement(bfloat16 f,bool print_v2)1011 inline float PrintOneElement(bfloat16 f, bool print_v2) {
1012   return static_cast<float>(f);
1013 }
1014 
1015 // Print from left dim to right dim recursively.
1016 template <typename T>
PrintOneDim(int dim_index,const gtl::InlinedVector<int64,4> & shape,int64 limit,int shape_size,const T * data,int64 * data_index,string * result)1017 void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
1018                  int64 limit, int shape_size, const T* data, int64* data_index,
1019                  string* result) {
1020   if (*data_index >= limit) return;
1021   int64 element_count = shape[dim_index];
1022   // We have reached the right-most dimension of the tensor.
1023   if (dim_index == shape_size - 1) {
1024     for (int64 i = 0; i < element_count; i++) {
1025       if (*data_index >= limit) {
1026         // If not enough elements has been printed, append "...".
1027         if (dim_index != 0) {
1028           strings::StrAppend(result, "...");
1029         }
1030         return;
1031       }
1032       if (i > 0) strings::StrAppend(result, " ");
1033       strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
1034     }
1035     return;
1036   }
1037   // Loop every element of one dim.
1038   for (int64 i = 0; i < element_count; i++) {
1039     bool flag = false;
1040     if (*data_index < limit) {
1041       strings::StrAppend(result, "[");
1042       flag = true;
1043     }
1044     // As for each element, print the sub-dim.
1045     PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
1046                 result);
1047     if (*data_index < limit || flag) {
1048       strings::StrAppend(result, "]");
1049       flag = false;
1050     }
1051   }
1052 }
1053 
1054 // Appends the spacing between elements for a given dim onto a result string
PrintDimSpacing(int dim_index,int num_dims,string * result)1055 void PrintDimSpacing(int dim_index, int num_dims, string* result) {
1056   if (dim_index == num_dims - 1) {
1057     strings::StrAppend(result, " ");
1058     return;
1059   }
1060   for (int j = 0; j < num_dims - dim_index - 1; j++) {
1061     strings::StrAppend(result, "\n");
1062   }
1063   for (int j = 0; j <= dim_index; j++) {
1064     strings::StrAppend(result, " ");
1065   }
1066 }
1067 
1068 // Print from left dim to right dim recursively.
1069 template <typename T>
PrintOneDimV2(int dim_index,const gtl::InlinedVector<int64,4> & shape,int64 num_elts_at_ends,int num_dims,const T * data,int64 data_index,string * result)1070 void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
1071                    int64 num_elts_at_ends, int num_dims, const T* data,
1072                    int64 data_index, string* result) {
1073   // We have recursed beyond all the dimensions into a single element
1074   // of the tensor.
1075   if (dim_index == num_dims) {
1076     strings::StrAppend(result, PrintOneElement(data[data_index], true));
1077     return;
1078   }
1079 
1080   strings::StrAppend(result, "[");
1081   int64 element_count = shape[dim_index];
1082   int64 start_of_end =
1083       std::max(num_elts_at_ends, element_count - num_elts_at_ends);
1084 
1085   // Loop every element of one dim.
1086   int64 elements_per_iter = 1;
1087   for (int i = dim_index + 1; i < num_dims; i++) {
1088     elements_per_iter *= shape[i];
1089   }
1090   for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
1091     if (i > 0) {
1092       PrintDimSpacing(dim_index, num_dims, result);
1093     }
1094 
1095     // As for each element, print the sub-dim.
1096     PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1097                   data_index + elements_per_iter * i, result);
1098   }
1099   if (element_count > 2 * num_elts_at_ends) {
1100     PrintDimSpacing(dim_index, num_dims, result);
1101     strings::StrAppend(result, "...");
1102   }
1103   for (int64 i = start_of_end; i < element_count; i++) {
1104     // As for each element, print the sub-dim.
1105     PrintDimSpacing(dim_index, num_dims, result);
1106     PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1107                   data_index + elements_per_iter * i, result);
1108   }
1109 
1110   strings::StrAppend(result, "]");
1111 }
1112 
1113 template <typename T>
SummarizeArray(int64 limit,int64 num_elts,const TensorShape & tensor_shape,const char * data,const bool print_v2)1114 string SummarizeArray(int64 limit, int64 num_elts,
1115                       const TensorShape& tensor_shape, const char* data,
1116                       const bool print_v2) {
1117   string ret;
1118   const T* array = reinterpret_cast<const T*>(data);
1119 
1120   const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
1121   if (shape.empty()) {
1122     for (int64 i = 0; i < limit; ++i) {
1123       if (i > 0) strings::StrAppend(&ret, " ");
1124       strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
1125     }
1126     if (num_elts > limit) strings::StrAppend(&ret, "...");
1127     return ret;
1128   }
1129   if (print_v2) {
1130     const int num_dims = tensor_shape.dims();
1131     PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
1132   } else {
1133     int64 data_index = 0;
1134     const int shape_size = tensor_shape.dims();
1135     PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
1136 
1137     if (num_elts > limit) strings::StrAppend(&ret, "...");
1138   }
1139 
1140   return ret;
1141 }
1142 }  // namespace
1143 
SummarizeValue(int64 max_entries,bool print_v2) const1144 string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
1145   const int64 num_elts = NumElements();
1146   if (max_entries < 0) {
1147     max_entries = num_elts;
1148   }
1149   size_t limit = std::min(max_entries, num_elts);
1150   if ((limit > 0) && (buf_ == nullptr)) {
1151     return strings::StrCat("uninitialized Tensor of ", num_elts,
1152                            " elements of type ", dtype());
1153   }
1154   const char* data = limit > 0 ? tensor_data().data() : nullptr;
1155   switch (dtype()) {
1156     case DT_BFLOAT16:
1157       return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
1158       break;
1159     case DT_HALF:
1160       return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
1161                                          print_v2);
1162       break;
1163     case DT_FLOAT:
1164       return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
1165       break;
1166     case DT_DOUBLE:
1167       return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
1168       break;
1169     case DT_UINT32:
1170       return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
1171       break;
1172     case DT_INT32:
1173       return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
1174       break;
1175     case DT_UINT8:
1176     case DT_QUINT8:
1177       return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
1178       break;
1179     case DT_UINT16:
1180     case DT_QUINT16:
1181       return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
1182       break;
1183     case DT_INT16:
1184     case DT_QINT16:
1185       return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
1186       break;
1187     case DT_INT8:
1188     case DT_QINT8:
1189       return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
1190       break;
1191     case DT_UINT64:
1192       return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
1193       break;
1194     case DT_INT64:
1195       return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
1196       break;
1197     case DT_BOOL:
1198       // TODO(tucker): Is it better to emit "True False..."?  This
1199       // will emit "1 0..." which is more compact.
1200       return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
1201       break;
1202     case DT_STRING:
1203       return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
1204       break;
1205     default: {
1206       // All irregular cases
1207       string ret;
1208       if (print_v2) {
1209         strings::StrAppend(&ret, "[");
1210       }
1211       // TODO(irving): Don't call flat every time around this
1212       // loop.
1213       for (size_t i = 0; i < limit; ++i) {
1214         if (i > 0) strings::StrAppend(&ret, " ");
1215         switch (dtype()) {
1216           case DT_VARIANT: {
1217             const Variant& v = flat<Variant>()(i);
1218             strings::StrAppend(&ret, v.DebugString());
1219           } break;
1220           default:
1221             // TODO(zhifengc, josh11b): Pretty-print other types (bool,
1222             // complex64, quantized).
1223             strings::StrAppend(&ret, "?");
1224         }
1225       }
1226       if (max_entries < num_elts) strings::StrAppend(&ret, "...");
1227       if (print_v2) {
1228         strings::StrAppend(&ret, "]");
1229       }
1230       return ret;
1231     }
1232   }
1233 }
1234 
tensor_data() const1235 StringPiece Tensor::tensor_data() const {
1236   if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
1237   return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
1238 }
1239 
data() const1240 void* Tensor::data() const {
1241   if (buf_ == nullptr) return nullptr;  // Don't die for empty tensors
1242   return static_cast<void*>(buf_->data());
1243 }
1244 
SharesBufferWith(const Tensor & b) const1245 bool Tensor::SharesBufferWith(const Tensor& b) const {
1246   return buf_ != nullptr && b.buf_ != nullptr &&
1247          buf_->root_buffer() == b.buf_->root_buffer();
1248 }
1249 
DebugString(int num_values) const1250 string Tensor::DebugString(int num_values) const {
1251   return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1252                          " shape: ", shape().DebugString(),
1253                          " values: ", SummarizeValue(num_values), ">");
1254 }
1255 
DeviceSafeDebugString() const1256 string Tensor::DeviceSafeDebugString() const {
1257   return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1258                          " shape: ", shape().DebugString(), ">");
1259 }
1260 
FillDescription(TensorDescription * description) const1261 void Tensor::FillDescription(TensorDescription* description) const {
1262   description->set_dtype(dtype());
1263   shape().AsProto(description->mutable_shape());
1264   if (buf_ != nullptr && buf_->data() != nullptr) {
1265     buf_->FillAllocationDescription(
1266         description->mutable_allocation_description());
1267   }
1268 }
1269 
ComputeFlatInnerDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1270 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
1271     gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1272   gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1273   int64 offset = orig.size() - num_out_dims;
1274   for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
1275     const int64 in_dim = out_dim + offset;
1276     out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
1277   }
1278   for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
1279     out_dims[0] *= orig[in_dim];
1280   }
1281   return out_dims;
1282 }
1283 
ComputeFlatOuterDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1284 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
1285     gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1286   gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1287   for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
1288     out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
1289   }
1290   for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
1291     out_dims[num_out_dims - 1] *= orig[in_dim];
1292   }
1293   return out_dims;
1294 }
1295 
1296 }  // namespace tensorflow
1297