/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Tensor.cc uses a few templated classes and structs to facilitate
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T.  The routines
//   include running the constructor and destructor of T[], encoding
//   and decoding T[] into/from a Cord, etc.

#include "tensorflow/core/framework/tensor.h"

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/typed_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");

bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
  AllocationDescription allocation_description;
  FillAllocationDescription(&allocation_description);
  if (allocation_description.allocated_bytes() > 0) {
    *out_bytes = allocation_description.allocated_bytes();
    return true;
  } else {
    return false;
  }
}

namespace {

// An un-templated base class for Buffer.
class BufferBase : public TensorBuffer {
 public:
  explicit BufferBase(Allocator* alloc, void* data_ptr)
      : TensorBuffer(data_ptr), alloc_(alloc) {}

  TensorBuffer* root_buffer() override { return this; }

  bool GetAllocatedBytes(size_t* out_bytes) const override {
    if (alloc_->TracksAllocationSizes()) {
      *out_bytes = alloc_->AllocatedSize(data());
      return *out_bytes > 0;
    } else {
      return false;
    }
  }

  void FillAllocationDescription(AllocationDescription* proto) const override {
    void* data_ptr = data();
    int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(alloc_->Name());
    proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
    if (alloc_->TracksAllocationSizes()) {
      int64 ab = alloc_->AllocatedSize(data_ptr);
      proto->set_allocated_bytes(ab);
      int64 id = alloc_->AllocationId(data_ptr);
      if (id > 0) {
        proto->set_allocation_id(id);
      }
      if (RefCountIsOne()) {
        proto->set_has_single_reference(true);
      }
    }
  }

 protected:
  void RecordDeallocation() {
    LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
                                        alloc_->Name());
  }

  Allocator* const alloc_;
};

// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  Buffer(Allocator* a, int64 n);
  Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);

  size_t size() const override { return sizeof(T) * elem_; }

 private:
  int64 elem_;

  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};

void LogUnexpectedSize(int64 actual, int64 expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}

bool MemoryLoggingEnabled() {
  static bool memory_logging_enabled = LogMemory::IsEnabled();
  return memory_logging_enabled;
}

// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple_type<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encoder of simple type T to a string.  We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage.
  static int64 TotalBytes(TensorBuffer* in, int64 n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};

// Helper specialization for string (the only non-simple type we
// support).
template <>
struct Helper<tstring> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type string stored in "in" into Cord
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    port::EncodeStringList(in->base<const tstring>(), n, out);
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    Buffer<tstring>* buf = new Buffer<tstring>(a, n);
    tstring* strings = buf->template base<tstring>();
    if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    int64 tot = in->size();
    DCHECK_EQ(tot, sizeof(tstring) * n);
    const tstring* p = in->base<const tstring>();
    for (int i = 0; i < n; ++i, ++p) tot += p->size();
    return tot;
  }
};

template <>
struct Helper<ResourceHandle> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type ResourceHandle stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
                             port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type ResourceHandle from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    auto* buf = new Buffer<ResourceHandle>(a, n);
    ResourceHandle* ps = buf->template base<ResourceHandle>();
    if (ps == nullptr ||
        !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(ResourceHandle);
  }
};

template <>
struct Helper<Variant> {
  // Encodes "n" elements of type Variant stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    EncodeVariantList(in->base<const Variant>(), n,
                      port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type Variant from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    auto* buf = new Buffer<Variant>(a, n);
    Variant* ps = buf->template base<Variant>();
    if (ps == nullptr ||
        !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(Variant);
  }
};

template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
// int32, string, etc) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(tstring, tstring, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
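// For example, the PROTO_TRAITS(uint8, int32, int) expansion above means that
// a DT_UINT8 tensor serialized through the typed fields is stored in the
// TensorProto's repeated int32 field `int_val`, one widened entry per element;
// Begin() and NumElements() read that same field back when the tensor is
// rebuilt by FromProtoField<uint8>() further below.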

template <>
struct ProtoHelper<int64> {
  static const int64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const int64*>(proto.int64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  static void Fill(const int64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<uint64> {
  static const uint64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<ResourceHandle> {
  static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.resource_handle_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.resource_handle_val().size();
  }
  static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
    auto* handles = proto->mutable_resource_handle_val();
    handles->Clear();
    for (size_t i = 0; i < n; i++) {
      data[i].AsProto(handles->Add());
    }
  }
};

template <>
struct ProtoHelper<Variant> {
  static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
  Begin(const TensorProto& proto) {
    return proto.variant_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.variant_val().size();
  }
  static void Fill(const Variant* data, size_t n, TensorProto* proto) {
    auto* variant_values = proto->mutable_variant_val();
    variant_values->Clear();
    for (size_t i = 0; i < n; ++i) {
      VariantTensorData tmp;
      data[i].Encode(&tmp);
      tmp.ToProto(variant_values->Add());
    }
  }
};

template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<bfloat16> {
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
    }
  }
};

template <>
struct ProtoHelper<Eigen::half> {
  static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
    }
  }
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
      elem_(n) {}

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
      elem_(n) {}

template <typename T>
Buffer<T>::~Buffer() {
  if (data()) {
    if (MemoryLoggingEnabled()) {
      RecordDeallocation();
    }
    TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
  }
}

// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in".  If "in" has fewer values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine uses the typed fields (float_val, etc.) in the
// tensor proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto is
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }

  const int64 in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      if (std::is_trivially_copyable<T>::value) {
        const T last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      } else {
        const T& last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      }
    }
  }

  return buf;
}
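// Worked example of the padding behavior above (values derived from the code,
// shown purely for illustration): with T == float, a TensorProto whose
// float_val holds {1.0, 2.0}, and a requested element count n == 4, the
// returned buffer contains {1.0, 2.0, 2.0, 2.0}, i.e. the last provided value
// is repeated. With an empty float_val the buffer holds four value-initialized
// floats: {0.0, 0.0, 0.0, 0.0}.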

template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64 n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    // If the tensor shape says we have n < in_n elements in the output tensor
    // then make sure to only decode the first n out of the in_n elements of
    // the input. In all other cases, we decode all in_n elements of the input
    // and set the remaining elements up to n to be the default Variant()
    // value.
    const int64 real_n = n < in_n ? n : in_n;
    for (int64 i = 0; i < real_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\".  Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    for (int64 i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}

// fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
// identically to uint16 but with data stored in half_val instead of int_val
// (i.e., we don't use ProtoHelper<uint16>).
template <>
TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                          int64 n) {
  CHECK_GT(n, 0);
  Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
                                       int64 n) {
  CHECK_GT(n, 0);
  Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type.  E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32.  If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}

void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}

}  // end namespace

Tensor::Tensor() : Tensor(DT_FLOAT) {}

Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}

Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

bool Tensor::IsInitialized() const {
  return (buf_ != nullptr && buf_->data() != nullptr) ||
         shape_.num_elements() == 0;
}

void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
}

void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
  CHECK(IsAligned()) << "ptr = " << base<void>();
}

void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element";
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}

Tensor::~Tensor() { UnrefIfNonNull(buf_); }

void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
  CHECK_EQ(shape.num_elements(), other.NumElements());
  // Data type will be overwritten if this == &other, since dtype is part of
  // shape.
  DataType other_dtype = other.dtype();
  shape_ = shape;
  set_dtype(other_dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
}

Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
                           const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  if (in_size == 0) {
    return errors::InvalidArgument("other tensor has zero-sized data type");
  }
  if (out_size == 0) {
    return errors::InvalidArgument("specified output type is zero-sized");
  }
  if (shape.num_elements() * out_size !=
      other.shape().num_elements() * in_size) {
    return errors::InvalidArgument(
        "input and output shapes/data type sizes are not compatible");
  }
  shape_ = shape;
  shape_.set_data_type(dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
  return Status::OK();
}
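// Illustration of the size check above (assuming the usual 4-byte float and
// 1-byte uint8 element sizes): bitcasting a DT_FLOAT tensor of shape {4}
// (16 bytes) to DT_UINT8 requires a shape such as {16}, since 16 * 1 ==
// 4 * 4; requesting shape {4} would return an InvalidArgument error because
// the total byte counts differ.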

// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
// For the latter case, we have to make sure that the refcount is
// one both for the SubBuffer _and_ the underlying TensorBuffer.
bool Tensor::RefCountIsOne() const {
  return buf_ != nullptr && buf_->RefCountIsOne() &&
         buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
}

// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(tstring, SINGLE_ARG(STMTS))                           \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
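// For example, a call like CASES(dtype(), buf_ = new Buffer<T>(a, n)) expands
// to a switch over the runtime DataType in which each case introduces
// `typedef TYPE T;` before running the statement: the DT_FLOAT case allocates
// a Buffer<float>, the DT_INT32 case a Buffer<int32>, and so on, while
// DT_INVALID and unknown enum values hit the LOG(FATAL) branches supplied by
// CASES().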

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
      buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}

// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
// acquiring a lock, which guards initialization of the per-NUMA zone
// allocators, and becomes highly contended.
//
// Note also that it would be better if all Tensor allocations required the user
// to specify an allocator, for purposes of accounting, etc. However, the
// default allocator is widely used throughout the codebase and in client code.
static Allocator* get_default_cpu_allocator() {
  static Allocator* default_cpu_allocator =
      cpu_allocator(port::kNUMANoAffinity);
  return default_cpu_allocator;
}

Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(get_default_cpu_allocator(), type, shape) {}

bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
    size_t* out_bytes) const {
  // `this->FillAllocationDescription()` never sets allocated bytes information,
  // so we can short-circuit the construction of an `AllocationDescription`.
  return false;
}

void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
    AllocationDescription* proto) const {
  proto->set_requested_bytes(size());
  proto->set_allocator_name("HostScalarTensorBuffer");
  proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
}

template <typename T>
class SubBuffer : public TensorBuffer {
 public:
  // This buffer is an alias to buf[delta, delta + n).
  SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
      : TensorBuffer(buf->base<T>() + delta),
        root_(buf->root_buffer()),
        elem_(n) {
    // Sanity check. The caller should ensure the sub buffer is valid.
    CHECK_LE(root_->base<T>(), this->base<T>());
    T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
    CHECK_LE(this->base<T>(), root_limit);
    CHECK_LE(this->base<T>() + n, root_limit);
    // Hold a ref of the underlying root buffer.
    // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
    root_->Ref();
  }

  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return root_; }
  bool GetAllocatedBytes(size_t* out_bytes) const override {
    return root_->GetAllocatedBytes(out_bytes);
  }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    root_->FillAllocationDescription(proto);
  }

 private:
  TensorBuffer* root_;
  int64 elem_;

  ~SubBuffer() override { root_->Unref(); }

  TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
};

Tensor Tensor::Slice(int64 start, int64 limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64 dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.shape_ = shape_;
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64 elems_per_dim0 = NumElements() / dim0_size;
    const int64 delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64 num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
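// Example of the aliasing behavior above, using a hypothetical tensor `t` of
// shape {10, 20}: t.Slice(2, 5) returns a tensor of shape {3, 20} whose
// SubBuffer aliases elements [2 * 20, 5 * 20) of t's storage. No data is
// copied; the result shares, and holds a reference on, t's underlying root
// buffer.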

Tensor Tensor::SubSlice(int64 index) const {
  CHECK_GE(dims(), 1);  // Crash ok.
  CHECK_LE(0, index);   // Crash ok.
  int64 dim0_size = shape_.dim_size(0);
  CHECK_LE(index, dim0_size);  // Crash ok.
  Tensor ret;
  ret.shape_ = shape_;
  ret.shape_.RemoveDim(0);
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64 elems_per_dim0 = NumElements() / dim0_size;
    const int64 delta = index * elems_per_dim0;
    const int64 num_elems = elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
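// Example: for a hypothetical tensor `t` of shape {10, 20}, t.SubSlice(3)
// returns a tensor of shape {20} aliasing elements [3 * 20, 4 * 20) of the
// original storage. Unlike Slice(), the leading dimension is removed rather
// than resized.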

bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(get_default_cpu_allocator(), proto);
}

bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64 N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  }
  shape_ = shape;
  set_dtype(proto.dtype());
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}

void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  shape_.AsProto(proto->mutable_tensor_shape());
  proto->set_dtype(dtype());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}

void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}

size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}

size_t Tensor::AllocatedBytes() const {
  if (buf_) {
    size_t ret;
    if (buf_->GetAllocatedBytes(&ret)) {
      return ret;
    }
  }
  return TotalBytes();
}

bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}

#undef CASES
#undef CASE

namespace {

// StrCat and StrAppend don't support Eigen::half directly at the moment, and
// we would like to keep them compatible with their absl counterparts, for ease
// of migration. We could rely on errors::internal::PrepareForStrCat() but the
// logic is so simple we can just replicate it here, where it is close to its
// usage and easy to change later. And there's the extra benefit of not
// accessing an 'internal' namespace.
inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                bool print_v2) {
  return a;
}
inline string PrintOneElement(const tstring& a, bool print_v2) {
  if (print_v2) {
    return "\"" + absl::CEscape(a) + "\"";
  } else {
    return absl::CEscape(a);
  }
}
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
  return static_cast<float>(h);
}

inline float PrintOneElement(bfloat16 f, bool print_v2) {
  return static_cast<float>(f);
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64 limit, int shape_size, const T* data, int64* data_index,
                 string* result) {
  if (*data_index >= limit) return;
  int64 element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64 i = 0; i < element_count; i++) {
      if (*data_index >= limit) {
        // If not enough elements have been printed, append "...".
        if (dim_index != 0) {
          strings::StrAppend(result, "...");
        }
        return;
      }
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
    }
    return;
  }
  // Loop every element of one dim.
  for (int64 i = 0; i < element_count; i++) {
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // As for each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}

// Appends the spacing between elements for a given dim onto a result string
void PrintDimSpacing(int dim_index, int num_dims, string* result) {
  if (dim_index == num_dims - 1) {
    strings::StrAppend(result, " ");
    return;
  }
  for (int j = 0; j < num_dims - dim_index - 1; j++) {
    strings::StrAppend(result, "\n");
  }
  for (int j = 0; j <= dim_index; j++) {
    strings::StrAppend(result, " ");
  }
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                   int64 num_elts_at_ends, int num_dims, const T* data,
                   int64 data_index, string* result) {
  // We have recursed beyond all the dimensions into a single element
  // of the tensor.
  if (dim_index == num_dims) {
    strings::StrAppend(result, PrintOneElement(data[data_index], true));
    return;
  }

  strings::StrAppend(result, "[");
  int64 element_count = shape[dim_index];
  int64 start_of_end =
      std::max(num_elts_at_ends, element_count - num_elts_at_ends);

  // Loop every element of one dim.
  int64 elements_per_iter = 1;
  for (int i = dim_index + 1; i < num_dims; i++) {
    elements_per_iter *= shape[i];
  }
  for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
    if (i > 0) {
      PrintDimSpacing(dim_index, num_dims, result);
    }

    // As for each element, print the sub-dim.
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }
  if (element_count > 2 * num_elts_at_ends) {
    PrintDimSpacing(dim_index, num_dims, result);
    strings::StrAppend(result, "...");
  }
  for (int64 i = start_of_end; i < element_count; i++) {
    // As for each element, print the sub-dim.
    PrintDimSpacing(dim_index, num_dims, result);
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }

  strings::StrAppend(result, "]");
}

template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
                      const TensorShape& tensor_shape, const char* data,
                      const bool print_v2) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
  if (shape.empty()) {
    for (int64 i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  if (print_v2) {
    const int num_dims = tensor_shape.dims();
    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
  } else {
    int64 data_index = 0;
    const int shape_size = tensor_shape.dims();
    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

    if (num_elts > limit) strings::StrAppend(&ret, "...");
  }

  return ret;
}
}  // namespace

string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
  const int64 num_elts = NumElements();
  if (max_entries < 0) {
    max_entries = num_elts;
  }
  size_t limit = std::min(max_entries, num_elts);
  if ((limit > 0) && (buf_ == nullptr)) {
    return strings::StrCat("uninitialized Tensor of ", num_elts,
                           " elements of type ", dtype());
  }
  const char* data = limit > 0 ? tensor_data().data() : nullptr;
  switch (dtype()) {
    case DT_BFLOAT16:
      return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_HALF:
      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
                                         print_v2);
      break;
    case DT_FLOAT:
      return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_DOUBLE:
      return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT32:
      return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT32:
      return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT8:
    case DT_QUINT8:
      return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT16:
    case DT_QUINT16:
      return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT16:
    case DT_QINT16:
      return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT8:
    case DT_QINT8:
      return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT64:
      return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT64:
      return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_BOOL:
      // TODO(tucker): Is it better to emit "True False..."?  This
      // will emit "1 0..." which is more compact.
      return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_STRING:
      return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
      break;
    default: {
      // All irregular cases
      string ret;
      if (print_v2) {
        strings::StrAppend(&ret, "[");
      }
      // TODO(irving): Don't call flat every time around this
      // loop.
      for (size_t i = 0; i < limit; ++i) {
        if (i > 0) strings::StrAppend(&ret, " ");
        switch (dtype()) {
          case DT_VARIANT: {
            const Variant& v = flat<Variant>()(i);
            strings::StrAppend(&ret, v.DebugString());
          } break;
          default:
            // TODO(zhifengc, josh11b): Pretty-print other types (bool,
            // complex64, quantized).
            strings::StrAppend(&ret, "?");
        }
      }
      if (max_entries < num_elts) strings::StrAppend(&ret, "...");
      if (print_v2) {
        strings::StrAppend(&ret, "]");
      }
      return ret;
    }
  }
}

StringPiece Tensor::tensor_data() const {
  if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
  return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}

bool Tensor::SharesBufferWith(const Tensor& b) const {
  return buf_ != nullptr && b.buf_ != nullptr &&
         buf_->root_buffer() == b.buf_->root_buffer();
}

string Tensor::DebugString(int num_values) const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(num_values), ">");
}

string Tensor::DeviceSafeDebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(), ">");
}

void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  if (buf_ != nullptr && buf_->data() != nullptr) {
    buf_->FillAllocationDescription(
        description->mutable_allocation_description());
  }
}

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
    gtl::ArraySlice<int64> orig, int64 num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  int64 offset = orig.size() - num_out_dims;
  for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
    const int64 in_dim = out_dim + offset;
    out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
  }
  for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
    out_dims[0] *= orig[in_dim];
  }
  return out_dims;
}
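// Worked example: ComputeFlatInnerDims({2, 3, 4, 5}, 2) keeps the last output
// dimension as 5 and collapses everything in front of it into the first one,
// yielding {2 * 3 * 4, 5} == {24, 5}. When num_out_dims exceeds orig.size(),
// the missing leading dimensions are padded with 1, e.g.
// ComputeFlatInnerDims({6}, 3) == {1, 1, 6}.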

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
    gtl::ArraySlice<int64> orig, int64 num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
    out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
  }
  for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
    out_dims[num_out_dims - 1] *= orig[in_dim];
  }
  return out_dims;
}
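// Worked example: ComputeFlatOuterDims({2, 3, 4, 5}, 2) keeps the leading
// output dimension as 2 and collapses the rest into the last one, yielding
// {2, 3 * 4 * 5} == {2, 60}; ComputeFlatOuterDims({2, 3}, 4) pads the missing
// trailing dimensions with 1, yielding {2, 3, 1, 1}.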

}  // namespace tensorflow