/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Tensor.cc uses a few templated classes and structs to facilitate
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T. The routines
//   include running the constructors and destructors of T[], and
//   encoding and decoding T[] into/from a Cord, etc.
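//
// As an illustrative (non-normative) sketch of how these pieces fit
// together, a typed buffer is allocated and serialized roughly like this
// (names mirror the helpers defined below; the snippet is exposition
// only and is not compiled):
//
//   Allocator* a = cpu_allocator();
//   Buffer<float>* buf = new Buffer<float>(a, /*n=*/16);  // float[16]
//   TensorProto proto;
//   // Copy the raw bytes of the 16 floats into tensor_content.
//   Helper<float>::Encode(buf, 16, proto.mutable_tensor_content());
//   // ...and reconstruct a buffer from those bytes later:
//   TensorBuffer* decoded =
//       Helper<float>::Decode(a, proto.tensor_content(), 16);
//   buf->Unref();
//   if (decoded) decoded->Unref();
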
#include "tensorflow/core/framework/tensor.h"

#include <utility>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/typed_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");

bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
  AllocationDescription allocation_description;
  FillAllocationDescription(&allocation_description);
  if (allocation_description.allocated_bytes() > 0) {
    *out_bytes = allocation_description.allocated_bytes();
    return true;
  } else {
    return false;
  }
}

namespace {

// An un-templated base class for Buffer.
class BufferBase : public TensorBuffer {
 public:
  explicit BufferBase(Allocator* alloc, void* data_ptr)
      : TensorBuffer(data_ptr), alloc_(alloc) {}

  TensorBuffer* root_buffer() override { return this; }

  bool GetAllocatedBytes(size_t* out_bytes) const override {
    if (alloc_->TracksAllocationSizes()) {
      *out_bytes = alloc_->AllocatedSize(data());
      return *out_bytes > 0;
    } else {
      return false;
    }
  }

  void FillAllocationDescription(AllocationDescription* proto) const override {
    void* data_ptr = data();
    int64_t rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(alloc_->Name());
    proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
    if (alloc_->TracksAllocationSizes()) {
      int64_t ab = alloc_->AllocatedSize(data_ptr);
      proto->set_allocated_bytes(ab);
      int64_t id = alloc_->AllocationId(data_ptr);
      if (id > 0) {
        proto->set_allocation_id(id);
      }
      if (RefCountIsOne()) {
        proto->set_has_single_reference(true);
      }
    }
  }

 protected:
  void RecordDeallocation() {
    LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
                                        alloc_->Name());
  }

  Allocator* const alloc_;
};

// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  Buffer(Allocator* a, int64_t n);
  Buffer(Allocator* a, int64_t n, const AllocationAttributes& allocation_attr);

  size_t size() const override { return sizeof(T) * elem_; }

 private:
  int64 elem_;

  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};

void LogUnexpectedSize(int64_t actual, int64_t expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}

bool MemoryLoggingEnabled() {
  static bool memory_logging_enabled = LogMemory::IsEnabled();
  return memory_logging_enabled;
}

// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple_type<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encoder of simple type T to a string.  We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage.
  static int64 TotalBytes(TensorBuffer* in, int64_t n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};

// Helper specialization for string (the only non-simple type we
// support).
template <>
struct Helper<tstring> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type string stored in "in" into Cord
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    port::EncodeStringList(in->base<const tstring>(), n, out);
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    Buffer<tstring>* buf = new Buffer<tstring>(a, n);
    tstring* strings = buf->template base<tstring>();
    if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    int64_t tot = in->size();
    DCHECK_EQ(tot, sizeof(tstring) * n);
    const tstring* p = in->base<const tstring>();
    for (int i = 0; i < n; ++i, ++p) tot += p->size();
    return tot;
  }
};

template <>
struct Helper<ResourceHandle> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type ResourceHandle stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
                             port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    auto* buf = new Buffer<ResourceHandle>(a, n);
    ResourceHandle* ps = buf->template base<ResourceHandle>();
    if (ps == nullptr ||
        !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(ResourceHandle);
  }
};

template <>
struct Helper<Variant> {
  // Encodes "n" elements of type Variant stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    EncodeVariantList(in->base<const Variant>(), n,
                      port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type Variant from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    auto* buf = new Buffer<Variant>(a, n);
    Variant* ps = buf->template base<Variant>();
    if (ps == nullptr ||
        !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(Variant);
  }
};

template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
// int32, string, etc) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(tstring, tstring, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
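
// For reference, PROTO_TRAITS(uint8, int32, int) above expands to roughly
// the following specialization (hand-expanded here for exposition; the
// macro is authoritative):
//
//   template <>
//   struct ProtoHelper<uint8> {
//     typedef Helper<int32>::RepeatedFieldType FieldType;
//     static FieldType::const_iterator Begin(const TensorProto& proto) {
//       return proto.int_val().begin();
//     }
//     static size_t NumElements(const TensorProto& proto) {
//       return proto.int_val().size();
//     }
//     static void Fill(const uint8* data, size_t n, TensorProto* proto) {
//       typename ProtoHelper<uint8>::FieldType copy(data, data + n);
//       proto->mutable_int_val()->Swap(&copy);
//     }
//   };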

template <>
struct ProtoHelper<int64> {
  static const int64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const int64*>(proto.int64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  static void Fill(const int64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<uint64> {
  static const uint64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<ResourceHandle> {
  static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.resource_handle_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.resource_handle_val().size();
  }
  static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
    auto* handles = proto->mutable_resource_handle_val();
    handles->Clear();
    for (size_t i = 0; i < n; i++) {
      data[i].AsProto(handles->Add());
    }
  }
};

template <>
struct ProtoHelper<Variant> {
  static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
  Begin(const TensorProto& proto) {
    return proto.variant_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.variant_val().size();
  }
  static void Fill(const Variant* data, size_t n, TensorProto* proto) {
    auto* variant_values = proto->mutable_variant_val();
    variant_values->Clear();
    for (size_t i = 0; i < n; ++i) {
      VariantTensorData tmp;
      data[i].Encode(&tmp);
      tmp.ToProto(variant_values->Add());
    }
  }
};

template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<bfloat16> {
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(
          Eigen::numext::bit_cast<uint16>(data[i]));
    }
  }
};

template <>
struct ProtoHelper<Eigen::half> {
  static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(
          Eigen::numext::bit_cast<uint16>(data[i]));
    }
  }
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
      elem_(n) {}

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
      elem_(n) {}

template <typename T>
Buffer<T>::~Buffer() {
  if (data()) {
    if (MemoryLoggingEnabled()) {
      RecordDeallocation();
    }
    TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
  }
}

// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in".  If "in" has fewer values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine uses the typed fields (float_val, etc.) in the tensor
// proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto is
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }

  const int64_t in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      if (std::is_trivially_copyable<T>::value) {
        const T last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      } else {
        const T& last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      }
    }
  }

  return buf;
}
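
// Illustration (exposition only, not compiled): decoding a proto whose
// float_val holds {1, 2} into a tensor of 5 elements yields
// {1, 2, 2, 2, 2} -- the last value is repeated -- while an empty
// float_val yields {0, 0, 0, 0, 0}, the default value for T:
//
//   TensorProto proto;
//   proto.add_float_val(1.0f);
//   proto.add_float_val(2.0f);
//   TensorBuffer* buf = FromProtoField<float>(cpu_allocator(), proto, 5);
//   // buf now holds float[5] == {1, 2, 2, 2, 2}; caller must Unref().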

template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    // If the tensor shape says we have n < in_n elements in the output tensor
    // then make sure to only decode the first n out of the in_n elements in
    // the in tensors. In all other cases, we decode all in_n elements of in
    // and set the remaining elements up to n to be the default Variant()
    // value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\".  Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}

// fp16 and bfloat16 are opaque to the protobuf, so we deserialize them
// identically to uint16 but with data stored in half_val instead of int_val
// (i.e., we don't use ProtoHelper<uint16>).
template <>
TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                          int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
                                       int64_t n) {
  CHECK_GT(n, 0);
  Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64_t n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type.  E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32.  If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}

void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}

}  // end namespace

Tensor::Tensor() : Tensor(DT_FLOAT) {}

Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}

Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

Tensor::Tensor(DataType type, TensorShape shape,
               core::RefCountPtr<TensorBuffer> buf)
    : shape_(std::move(shape)), buf_(buf.release()) {
  set_dtype(type);
}

bool Tensor::IsInitialized() const {
  return (buf_ != nullptr && buf_->data() != nullptr) ||
         shape_.num_elements() == 0;
}

void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
}

void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
  CHECK(IsAligned()) << "ptr = " << base<void>();
}

void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element";
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}

Tensor::~Tensor() { UnrefIfNonNull(buf_); }

Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
                           const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  if (in_size == 0) {
    return errors::InvalidArgument("other tensor has zero-sized data type");
  }
  if (out_size == 0) {
    return errors::InvalidArgument("specified output type is zero-sized");
  }
  if (shape.num_elements() * out_size !=
      other.shape().num_elements() * in_size) {
    return errors::InvalidArgument(
        "input and output shapes/data type sizes are not compatible");
  }
  shape_ = shape;
  shape_.set_data_type(dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
  return Status::OK();
}

// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
// For the latter case, we have to make sure that the refcount is
// one both for the SubBuffer _and_ the underlying TensorBuffer.
bool Tensor::RefCountIsOne() const {
  return buf_ != nullptr && buf_->RefCountIsOne() &&
         buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
}

// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(tstring, SINGLE_ARG(STMTS))                           \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS,                               \
                     LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
                     , LOG(FATAL) << "Type not set";)
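
// For exposition, a call such as
//
//   CASES(dt, buf_ = new Buffer<T>(a, n));
//
// expands (per type case) to roughly the following switch:
//
//   switch (dt) {
//     case DataTypeToEnum<float>::value: {
//       typedef float T;
//       buf_ = new Buffer<T>(a, n);
//       break;
//     }
//     // ... one case per supported type ...
//     case DT_INVALID:
//       LOG(FATAL) << "Unexpected type: " << dt;
//       break;
//     default:
//       LOG(FATAL) << "Type not set";
//       break;
//   }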

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
      buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}

// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
// acquiring a lock, which guards initialization of the per-NUMA zone
// allocators, and becomes highly contended.
//
// Note also that it would be better if all Tensor allocations required the user
// to specify an allocator, for purposes of accounting, etc. However, the
// default allocator is widely used throughout the codebase and in client code.
static Allocator* get_default_cpu_allocator() {
  static Allocator* default_cpu_allocator =
      cpu_allocator(port::kNUMANoAffinity);
  return default_cpu_allocator;
}

Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(get_default_cpu_allocator(), type, shape) {}

bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
    size_t* out_bytes) const {
  // `this->FillAllocationDescription()` never sets allocated bytes information,
  // so we can short-circuit the construction of an `AllocationDescription`.
  return false;
}

void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
    AllocationDescription* proto) const {
  proto->set_requested_bytes(size());
  proto->set_allocator_name("HostScalarTensorBuffer");
  proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
}

template <typename T>
class SubBuffer : public TensorBuffer {
 public:
  // This buffer is an alias to buf[delta, delta + n).
  SubBuffer(TensorBuffer* buf, int64_t delta, int64_t n)
      : TensorBuffer(buf->base<T>() + delta),
        root_(buf->root_buffer()),
        elem_(n) {
    // Sanity check. The caller should ensure the sub buffer is valid.
    CHECK_LE(root_->base<T>(), this->base<T>());
    T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
    CHECK_LE(this->base<T>(), root_limit);
    CHECK_LE(this->base<T>() + n, root_limit);
    // Hold a ref of the underlying root buffer.
    // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
    root_->Ref();
  }

  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return root_; }
  bool GetAllocatedBytes(size_t* out_bytes) const override {
    return root_->GetAllocatedBytes(out_bytes);
  }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    root_->FillAllocationDescription(proto);
  }

 private:
  TensorBuffer* root_;
  int64 elem_;

  ~SubBuffer() override { root_->Unref(); }

  TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
};

Tensor Tensor::Slice(int64_t start, int64_t limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64_t dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.shape_ = shape_;
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64_t elems_per_dim0 = NumElements() / dim0_size;
    const int64_t delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64_t num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}

Tensor Tensor::SubSlice(int64_t index) const {
  CHECK_GE(dims(), 1);  // Crash ok.
  CHECK_LE(0, index);   // Crash ok.
  int64_t dim0_size = shape_.dim_size(0);
  CHECK_LE(index, dim0_size);  // Crash ok.
  Tensor ret;
  ret.shape_ = shape_;
  ret.shape_.RemoveDim(0);
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64_t elems_per_dim0 = NumElements() / dim0_size;
    const int64_t delta = index * elems_per_dim0;
    const int64_t num_elems = elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}

bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(get_default_cpu_allocator(), proto);
}

bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64_t N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  }
  shape_ = shape;
  set_dtype(proto.dtype());
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}

void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  shape_.AsProto(proto->mutable_tensor_shape());
  proto->set_dtype(dtype());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}

void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}

size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}

size_t Tensor::AllocatedBytes() const {
  if (buf_) {
    size_t ret;
    if (buf_->GetAllocatedBytes(&ret)) {
      return ret;
    }
  }
  return TotalBytes();
}

bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}

#undef CASES
#undef CASE

namespace {

// StrCat and StrAppend don't support Eigen::half directly at the moment, and
// we would like to keep them compatible with their absl counterparts, for ease
// of migration. We could rely on errors::internal::PrepareForStrCat() but the
// logic is so simple we can just replicate it here, where it is close to its
// usage and easy to change later. And there's the extra benefit of not
// accessing an 'internal' namespace.
inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                bool print_v2) {
  return a;
}
inline string PrintOneElement(const tstring& a, bool print_v2) {
  if (print_v2) {
    return "\"" + absl::Utf8SafeCEscape(a) + "\"";
  } else {
    return absl::Utf8SafeCEscape(a);
  }
}
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
  return static_cast<float>(h);
}

inline float PrintOneElement(bfloat16 f, bool print_v2) {
  return static_cast<float>(f);
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64_t limit, int shape_size, const T* data,
                 int64* data_index, string* result) {
  if (*data_index >= limit) return;
  int64_t element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64_t i = 0; i < element_count; i++) {
      if (*data_index >= limit) {
        // If not enough elements have been printed, append "...".
        if (dim_index != 0) {
          strings::StrAppend(result, "...");
        }
        return;
      }
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
    }
    return;
  }
  // Loop every element of one dim.
  for (int64_t i = 0; i < element_count; i++) {
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // As for each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}

// Appends the spacing between elements for a given dim onto a result string
void PrintDimSpacing(int dim_index, int num_dims, string* result) {
  if (dim_index == num_dims - 1) {
    strings::StrAppend(result, " ");
    return;
  }
  for (int j = 0; j < num_dims - dim_index - 1; j++) {
    strings::StrAppend(result, "\n");
  }
  for (int j = 0; j <= dim_index; j++) {
    strings::StrAppend(result, " ");
  }
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                   int64_t num_elts_at_ends, int num_dims, const T* data,
                   int64_t data_index, string* result) {
  // We have recursed beyond all the dimensions into a single element
  // of the tensor.
  if (dim_index == num_dims) {
    strings::StrAppend(result, PrintOneElement(data[data_index], true));
    return;
  }

  strings::StrAppend(result, "[");
  int64_t element_count = shape[dim_index];
  int64_t start_of_end =
      std::max(num_elts_at_ends, element_count - num_elts_at_ends);

  // Loop every element of one dim.
  int64_t elements_per_iter = 1;
  for (int i = dim_index + 1; i < num_dims; i++) {
    elements_per_iter *= shape[i];
  }
  for (int64_t i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
    if (i > 0) {
      PrintDimSpacing(dim_index, num_dims, result);
    }

    // As for each element, print the sub-dim.
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }
  if (element_count > 2 * num_elts_at_ends) {
    PrintDimSpacing(dim_index, num_dims, result);
    strings::StrAppend(result, "...");
  }
  for (int64_t i = start_of_end; i < element_count; i++) {
    // As for each element, print the sub-dim.
    PrintDimSpacing(dim_index, num_dims, result);
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }

  strings::StrAppend(result, "]");
}

template <typename T>
string SummarizeArray(int64_t limit, int64_t num_elts,
                      const TensorShape& tensor_shape, const char* data,
                      const bool print_v2) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
  if (shape.empty()) {
    for (int64_t i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  if (print_v2) {
    const int num_dims = tensor_shape.dims();
    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
  } else {
    int64_t data_index = 0;
    const int shape_size = tensor_shape.dims();
    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

    if (num_elts > limit) strings::StrAppend(&ret, "...");
  }

  return ret;
}
}  // namespace

string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const {
  const int64_t num_elts = NumElements();
  if (max_entries < 0) {
    max_entries = num_elts;
  }
  size_t limit = std::min(max_entries, num_elts);
  if ((limit > 0) && (buf_ == nullptr)) {
    return strings::StrCat("uninitialized Tensor of ", num_elts,
                           " elements of type ", dtype());
  }
  const char* data = limit > 0 ? tensor_data().data() : nullptr;
  switch (dtype()) {
    case DT_BFLOAT16:
      return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_HALF:
      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
                                         print_v2);
      break;
    case DT_FLOAT:
      return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_DOUBLE:
      return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT32:
      return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT32:
      return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT8:
    case DT_QUINT8:
      return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT16:
    case DT_QUINT16:
      return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT16:
    case DT_QINT16:
      return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT8:
    case DT_QINT8:
      return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT64:
      return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT64:
      return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_BOOL:
      // TODO(tucker): Is it better to emit "True False..."?  This
      // will emit "1 0..." which is more compact.
      return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_STRING:
      return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
      break;
    default: {
      // All irregular cases
      string ret;
      if (print_v2 && (dims() > 0)) {
        strings::StrAppend(&ret, "[");
      }
      // TODO(irving): Don't call flat every time around this
      // loop.
      for (size_t i = 0; i < limit; ++i) {
        if (i > 0) strings::StrAppend(&ret, " ");
        switch (dtype()) {
          case DT_VARIANT: {
            const Variant& v = flat<Variant>()(i);
            strings::StrAppend(&ret, "<", v.SummarizeValue(), ">");
          } break;
          case DT_RESOURCE: {
            const ResourceHandle& r = flat<ResourceHandle>()(i);
            strings::StrAppend(&ret, "<", r.SummarizeValue(), ">");
          } break;
          default:
            // TODO(zhifengc, josh11b): Pretty-print other types (bool,
            // complex64, quantized).
            strings::StrAppend(&ret, "?");
        }
      }
      if (max_entries < num_elts) strings::StrAppend(&ret, "...");
      if (print_v2 && (dims() > 0)) {
        strings::StrAppend(&ret, "]");
      }
      return ret;
    }
  }
}

StringPiece Tensor::tensor_data() const {
  if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
  return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}

void* Tensor::data() const {
  if (buf_ == nullptr) return nullptr;  // Don't die for empty tensors
  return static_cast<void*>(buf_->data());
}

bool Tensor::SharesBufferWith(const Tensor& b) const {
  return buf_ != nullptr && b.buf_ != nullptr &&
         buf_->root_buffer() == b.buf_->root_buffer();
}

string Tensor::DebugString(int num_values) const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(num_values), ">");
}

string Tensor::DeviceSafeDebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(), ">");
}

void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  if (buf_ != nullptr && buf_->data() != nullptr) {
    buf_->FillAllocationDescription(
        description->mutable_allocation_description());
  }
}

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
    gtl::ArraySlice<int64> orig, int64_t num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  int64_t offset = orig.size() - num_out_dims;
  for (int64_t out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
    const int64_t in_dim = out_dim + offset;
    out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
  }
  for (int64_t in_dim = 0; in_dim < offset; ++in_dim) {
    out_dims[0] *= orig[in_dim];
  }
  return out_dims;
}

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
    gtl::ArraySlice<int64> orig, int64_t num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  for (int64_t out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
    out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
  }
  for (int64_t in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
    out_dims[num_out_dims - 1] *= orig[in_dim];
  }
  return out_dims;
}
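
// Worked example (exposition only): for orig = {2, 3, 4},
// ComputeFlatInnerDims(orig, 2) collapses the leading dims into the first
// output dim, giving {2 * 3, 4} == {6, 4}, while
// ComputeFlatOuterDims(orig, 2) collapses the trailing dims into the last
// output dim, giving {2, 3 * 4} == {2, 12}. If num_out_dims exceeds
// orig.size(), the extra output dims are filled with 1.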

}  // namespace tensorflow