• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Implementation notes:
//
// tensor.cc uses a few templated classes and structs to facilitate
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T.  The routines
//   include running the constructor and destructor of T[], encoding
//   and decoding T[] into/from a Cord, etc.
29 
30 #include "tensorflow/core/framework/tensor.h"
31 
32 #include "tensorflow/core/framework/allocation_description.pb.h"
33 #include "tensorflow/core/framework/log_memory.h"
34 #include "tensorflow/core/framework/resource_handle.pb.h"
35 #include "tensorflow/core/framework/tensor.pb.h"
36 #include "tensorflow/core/framework/tensor_description.pb.h"
37 #include "tensorflow/core/framework/type_traits.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/framework/variant.h"
40 #include "tensorflow/core/framework/variant_encode_decode.h"
41 #include "tensorflow/core/framework/variant_op_registry.h"
42 #include "tensorflow/core/framework/variant_tensor_data.h"
43 #include "tensorflow/core/lib/core/coding.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/gtl/inlined_vector.h"
47 #include "tensorflow/core/lib/gtl/stl_util.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/platform/protobuf.h"
53 #include "tensorflow/core/platform/tensor_coding.h"
54 #include "tensorflow/core/platform/types.h"
55 
56 namespace tensorflow {
57 
// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// The string argument is presumably the type name the decoder is keyed on
// (FromProtoField<Variant> reports failures by Variant type_name) -- confirm
// against variant_op_registry.h.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
66 
67 namespace {
68 
69 // An un-templated base class for Buffer.
70 class BufferBase : public TensorBuffer {
71  public:
BufferBase(Allocator * alloc,void * data_ptr)72   explicit BufferBase(Allocator* alloc, void* data_ptr)
73       : TensorBuffer(data_ptr), alloc_(alloc) {}
74 
root_buffer()75   TensorBuffer* root_buffer() override { return this; }
FillAllocationDescription(AllocationDescription * proto) const76   void FillAllocationDescription(AllocationDescription* proto) const override {
77     void* data_ptr = data();
78     int64 rb = size();
79     proto->set_requested_bytes(rb);
80     proto->set_allocator_name(alloc_->Name());
81     proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
82     if (alloc_->TracksAllocationSizes()) {
83       int64 ab = alloc_->AllocatedSize(data_ptr);
84       proto->set_allocated_bytes(ab);
85       int64 id = alloc_->AllocationId(data_ptr);
86       if (id > 0) {
87         proto->set_allocation_id(id);
88       }
89       if (RefCountIsOne()) {
90         proto->set_has_single_reference(true);
91       }
92     }
93   }
94 
95  protected:
RecordDeallocation()96   void RecordDeallocation() {
97     LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
98                                         alloc_->Name());
99   }
100 
101   Allocator* const alloc_;
102 };
103 
104 // Typed ref-counted buffer: T[n].
105 template <typename T>
106 class Buffer : public BufferBase {
107  public:
108   Buffer(Allocator* a, int64 n);
109   Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);
110 
size() const111   size_t size() const override { return sizeof(T) * elem_; }
112 
113  private:
114   T* data_;
115   int64 elem_;
116 
117   ~Buffer() override;
118 
119   TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
120 };
121 
LogUnexpectedSize(int64 actual,int64 expected)122 void LogUnexpectedSize(int64 actual, int64 expected) {
123   LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
124 }
125 
126 // A set of helper functions depending on T.
127 template <typename T>
128 struct Helper {
129   // By default, we assume T is a simple type (float, int32, etc.)
130   static_assert(is_simple_type<T>::value, "T is not a simple type.");
131   typedef protobuf::RepeatedField<T> RepeatedFieldType;
132 
133   // Encoder of simple type T to a string.  We do a copy.
134   template <typename Destination>
Encodetensorflow::__anonae1448e20111::Helper135   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
136     DCHECK_EQ(in->size(), sizeof(T) * n);
137     port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
138                            out);
139   }
140 
141   // Decoder of simple type T. Copy the bytes from "in" into the
142   // tensor buffer.
143   template <typename Source>
Decodetensorflow::__anonae1448e20111::Helper144   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
145     if (in.size() != sizeof(T) * n) {
146       LogUnexpectedSize(in.size(), sizeof(T) * n);
147       return nullptr;
148     }
149     Buffer<T>* buf = new Buffer<T>(a, n);
150     char* data = buf->template base<char>();
151     if (data == nullptr) {
152       buf->Unref();
153       return nullptr;
154     }
155     port::CopyToArray(in, data);
156     return buf;
157   }
158 
159   // Memory usage.
TotalBytestensorflow::__anonae1448e20111::Helper160   static int64 TotalBytes(TensorBuffer* in, int64 n) {
161     DCHECK_EQ(in->size(), sizeof(T) * n);
162     return in->size();
163   }
164 };
165 
166 // Helper specialization for string (the only non-simple type we
167 // support).
168 template <>
169 struct Helper<string> {
170   // Proto message uses RepeatedFieldType to hold repeated T.
171   typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
172 
173   // Encodes "n" elements of type string stored in "in" into Cord
174   // "out", which is usually the TensorProto::tensor_content.
175   template <typename Destination>
Encodetensorflow::__anonae1448e20111::Helper176   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
177     port::EncodeStringList(in->base<const string>(), n, out);
178   }
179 
180   // Decodes "n" elements of type string from "in" and constructs a
181   // buffer out of it. Returns nullptr if the decoding fails. "in" is
182   // usually the TensorProto::tensor_content.
183   template <typename Source>
Decodetensorflow::__anonae1448e20111::Helper184   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
185     Buffer<string>* buf = new Buffer<string>(a, n);
186     string* strings = buf->template base<string>();
187     if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
188       buf->Unref();
189       return nullptr;
190     }
191     return buf;
192   }
193 
194   // Returns the estimated memory usage of "n" elements of type T
195   // stored in buffer "in".
TotalBytestensorflow::__anonae1448e20111::Helper196   static int64 TotalBytes(TensorBuffer* in, int n) {
197     int64 tot = in->size();
198     DCHECK_EQ(tot, sizeof(string) * n);
199     const string* p = in->base<const string>();
200     for (int i = 0; i < n; ++i, ++p) tot += p->size();
201     return tot;
202   }
203 };
204 
205 template <>
206 struct Helper<ResourceHandle> {
207   // Proto message uses RepeatedFieldType to hold repeated T.
208   typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
209 
210   // Encodes "n" elements of type ResourceHandle stored in "in" into destination
211   // "out", which is usually the TensorProto::tensor_content.
212   template <typename Destination>
Encodetensorflow::__anonae1448e20111::Helper213   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
214     EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
215                              port::NewStringListEncoder(out));
216   }
217 
218   // Decodes "n" elements of type string from "in" and constructs a
219   // buffer out of it. Returns nullptr if the decoding fails. "in" is
220   // usually the TensorProto::tensor_content.
221   template <typename Source>
Decodetensorflow::__anonae1448e20111::Helper222   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
223     auto* buf = new Buffer<ResourceHandle>(a, n);
224     ResourceHandle* ps = buf->template base<ResourceHandle>();
225     if (ps == nullptr ||
226         !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
227       buf->Unref();
228       return nullptr;
229     }
230     return buf;
231   }
232 
233   // Returns the estimated memory usage of "n" elements of type T
234   // stored in buffer "in".
TotalBytestensorflow::__anonae1448e20111::Helper235   static int64 TotalBytes(TensorBuffer* in, int n) {
236     return n * sizeof(ResourceHandle);
237   }
238 };
239 
240 template <>
241 struct Helper<Variant> {
242   // Encodes "n" elements of type Variant stored in "in" into destination
243   // "out", which is usually the TensorProto::tensor_content.
244   template <typename Destination>
Encodetensorflow::__anonae1448e20111::Helper245   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
246     EncodeVariantList(in->base<const Variant>(), n,
247                       port::NewStringListEncoder(out));
248   }
249 
250   // Decodes "n" elements of type Variant from "in" and constructs a
251   // buffer out of it. Returns nullptr if the decoding fails. "in" is
252   // usually the TensorProto::tensor_content.
253   template <typename Source>
Decodetensorflow::__anonae1448e20111::Helper254   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
255     auto* buf = new Buffer<Variant>(a, n);
256     Variant* ps = buf->template base<Variant>();
257     if (ps == nullptr ||
258         !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
259       buf->Unref();
260       return nullptr;
261     }
262     return buf;
263   }
264 
265   // Returns the estimated memory usage of "n" elements of type T
266   // stored in buffer "in".
TotalBytestensorflow::__anonae1448e20111::Helper267   static int64 TotalBytes(TensorBuffer* in, int n) {
268     return n * sizeof(Variant);
269   }
270 };
271 
272 template <typename T>
273 struct ProtoHelper {};
274 
275 // For a C++ type "T" (float, double, int32, etc.), the repeated field
276 // "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
277 // int32, string, etc) in the TensorProto is used for serializing the
278 // tensor of type "T".
279 #define PROTO_TRAITS(T, F, N)                                          \
280   template <>                                                          \
281   struct ProtoHelper<T> {                                              \
282     typedef Helper<F>::RepeatedFieldType FieldType;                    \
283     static FieldType::const_iterator Begin(const TensorProto& proto) { \
284       return proto.N##_val().begin();                                  \
285     }                                                                  \
286     static size_t NumElements(const TensorProto& proto) {              \
287       return proto.N##_val().size();                                   \
288     }                                                                  \
289     static void Fill(const T* data, size_t n, TensorProto* proto) {    \
290       typename ProtoHelper<T>::FieldType copy(data, data + n);         \
291       proto->mutable_##N##_val()->Swap(&copy);                         \
292     }                                                                  \
293   };
294 PROTO_TRAITS(float, float, float);
295 PROTO_TRAITS(double, double, double);
296 PROTO_TRAITS(int32, int32, int);
297 PROTO_TRAITS(uint8, int32, int);
298 PROTO_TRAITS(uint16, int32, int);
299 PROTO_TRAITS(uint32, uint32, uint32);
300 PROTO_TRAITS(int16, int32, int);
301 PROTO_TRAITS(int8, int32, int);
302 PROTO_TRAITS(bool, bool, bool);
303 PROTO_TRAITS(string, string, string);
304 PROTO_TRAITS(qint8, int32, int);
305 PROTO_TRAITS(quint8, int32, int);
306 PROTO_TRAITS(qint16, int32, int);
307 PROTO_TRAITS(quint16, int32, int);
308 #undef PROTO_TRAITS
309 
310 template <>
311 struct ProtoHelper<int64> {
Begintensorflow::__anonae1448e20111::ProtoHelper312   static const int64* Begin(const TensorProto& proto) {
313     return reinterpret_cast<const int64*>(proto.int64_val().begin());
314   }
NumElementstensorflow::__anonae1448e20111::ProtoHelper315   static size_t NumElements(const TensorProto& proto) {
316     return proto.int64_val().size();
317   }
Filltensorflow::__anonae1448e20111::ProtoHelper318   static void Fill(const int64* data, size_t n, TensorProto* proto) {
319     protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
320     proto->mutable_int64_val()->Swap(&copy);
321   }
322 };
323 
324 template <>
325 struct ProtoHelper<uint64> {
Begintensorflow::__anonae1448e20111::ProtoHelper326   static const uint64* Begin(const TensorProto& proto) {
327     return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
328   }
NumElementstensorflow::__anonae1448e20111::ProtoHelper329   static size_t NumElements(const TensorProto& proto) {
330     return proto.uint64_val().size();
331   }
Filltensorflow::__anonae1448e20111::ProtoHelper332   static void Fill(const uint64* data, size_t n, TensorProto* proto) {
333     protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
334     proto->mutable_uint64_val()->Swap(&copy);
335   }
336 };
337 
338 template <>
339 struct ProtoHelper<ResourceHandle> {
Begintensorflow::__anonae1448e20111::ProtoHelper340   static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
341       const TensorProto& proto) {
342     return proto.resource_handle_val().begin();
343   }
NumElementstensorflow::__anonae1448e20111::ProtoHelper344   static size_t NumElements(const TensorProto& proto) {
345     return proto.resource_handle_val().size();
346   }
Filltensorflow::__anonae1448e20111::ProtoHelper347   static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
348     auto* handles = proto->mutable_resource_handle_val();
349     handles->Clear();
350     for (size_t i = 0; i < n; i++) {
351       data[i].AsProto(handles->Add());
352     }
353   }
354 };
355 
356 template <>
357 struct ProtoHelper<Variant> {
358   static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
Begintensorflow::__anonae1448e20111::ProtoHelper359   Begin(const TensorProto& proto) {
360     return proto.variant_val().begin();
361   }
NumElementstensorflow::__anonae1448e20111::ProtoHelper362   static size_t NumElements(const TensorProto& proto) {
363     return proto.variant_val().size();
364   }
Filltensorflow::__anonae1448e20111::ProtoHelper365   static void Fill(const Variant* data, size_t n, TensorProto* proto) {
366     auto* variant_values = proto->mutable_variant_val();
367     variant_values->Clear();
368     for (size_t i = 0; i < n; ++i) {
369       VariantTensorData tmp;
370       data[i].Encode(&tmp);
371       tmp.ToProto(variant_values->Add());
372     }
373   }
374 };
375 
376 template <>
377 struct ProtoHelper<complex64> {
378   typedef Helper<float>::RepeatedFieldType FieldType;
Begintensorflow::__anonae1448e20111::ProtoHelper379   static const complex64* Begin(const TensorProto& proto) {
380     return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
381   }
NumElementstensorflow::__anonae1448e20111::ProtoHelper382   static size_t NumElements(const TensorProto& proto) {
383     return proto.scomplex_val().size() / 2;
384   }
Filltensorflow::__anonae1448e20111::ProtoHelper385   static void Fill(const complex64* data, size_t n, TensorProto* proto) {
386     const float* p = reinterpret_cast<const float*>(data);
387     FieldType copy(p, p + n * 2);
388     proto->mutable_scomplex_val()->Swap(&copy);
389   }
390 };
391 
392 template <>
393 struct ProtoHelper<complex128> {
394   typedef Helper<double>::RepeatedFieldType FieldType;
Begintensorflow::__anonae1448e20111::ProtoHelper395   static const complex128* Begin(const TensorProto& proto) {
396     return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
397   }
NumElementstensorflow::__anonae1448e20111::ProtoHelper398   static size_t NumElements(const TensorProto& proto) {
399     return proto.dcomplex_val().size() / 2;
400   }
Filltensorflow::__anonae1448e20111::ProtoHelper401   static void Fill(const complex128* data, size_t n, TensorProto* proto) {
402     const double* p = reinterpret_cast<const double*>(data);
403     FieldType copy(p, p + n * 2);
404     proto->mutable_dcomplex_val()->Swap(&copy);
405   }
406 };
407 
408 template <>
409 struct ProtoHelper<qint32> {
410   typedef Helper<int32>::RepeatedFieldType FieldType;
Begintensorflow::__anonae1448e20111::ProtoHelper411   static const qint32* Begin(const TensorProto& proto) {
412     return reinterpret_cast<const qint32*>(proto.int_val().data());
413   }
NumElementstensorflow::__anonae1448e20111::ProtoHelper414   static size_t NumElements(const TensorProto& proto) {
415     return proto.int_val().size();
416   }
Filltensorflow::__anonae1448e20111::ProtoHelper417   static void Fill(const qint32* data, size_t n, TensorProto* proto) {
418     const int32* p = reinterpret_cast<const int32*>(data);
419     FieldType copy(p, p + n);
420     proto->mutable_int_val()->Swap(&copy);
421   }
422 };
423 
424 template <>
425 struct ProtoHelper<bfloat16> {
Filltensorflow::__anonae1448e20111::ProtoHelper426   static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
427     proto->mutable_half_val()->Reserve(n);
428     for (size_t i = 0; i < n; ++i) {
429       proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
430     }
431   }
432 };
433 
434 template <>
435 struct ProtoHelper<Eigen::half> {
Filltensorflow::__anonae1448e20111::ProtoHelper436   static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
437     proto->mutable_half_val()->Reserve(n);
438     for (size_t i = 0; i < n; ++i) {
439       proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
440     }
441   }
442 };
443 
444 template <typename T>
Buffer(Allocator * a,int64 n)445 Buffer<T>::Buffer(Allocator* a, int64 n)
446     : BufferBase(a, a->Allocate<T>(n)), elem_(n) {}
447 
448 template <typename T>
Buffer(Allocator * a,int64 n,const AllocationAttributes & allocation_attr)449 Buffer<T>::Buffer(Allocator* a, int64 n,
450                   const AllocationAttributes& allocation_attr)
451     : BufferBase(a, a->Allocate<T>(n, allocation_attr)), elem_(n) {}
452 
453 template <typename T>
~Buffer()454 Buffer<T>::~Buffer() {
455   if (data()) {
456     if (LogMemory::IsEnabled()) {
457       RecordDeallocation();
458     }
459     alloc_->Deallocate<T>(static_cast<T*>(data()), elem_);
460   }
461 }
462 
463 // Allocates a T[n] buffer. Fills in the buffer with repeated values
464 // in "in".  If "in" has less values than "n", fills the rest of T[n]
465 // with the last value. If "in" has no values, fills T[n] with the
466 // default value for T.
467 //
468 // This routine is using the typed fields (float_val, etc.) in the
469 // tensor proto as opposed to the untyped binary representation
470 // (tensor_content). This is used when we expect the TensorProto is
471 // used by a client program which may not know how to encode a tensor
472 // in the compact binary representation.
473 template <typename T>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)474 TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
475   CHECK_GT(n, 0);
476   Buffer<T>* buf = new Buffer<T>(a, n);
477   T* data = buf->template base<T>();
478   if (data == nullptr) {
479     buf->Unref();
480     return nullptr;
481   }
482 
483   const int64 in_n = ProtoHelper<T>::NumElements(in);
484   if (in_n <= 0) {
485     std::fill_n(data, n, T());
486   } else {
487     auto begin = ProtoHelper<T>::Begin(in);
488     if (n <= in_n) {
489       std::copy_n(begin, n, data);
490     } else {
491       std::copy_n(begin, in_n, data);
492       const T& last = *(data + in_n - 1);
493       std::fill_n(data + in_n, n - in_n, last);
494     }
495   }
496 
497   return buf;
498 }
499 
500 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)501 TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
502                                       int64 n) {
503   CHECK_GT(n, 0);
504   Buffer<Variant>* buf = new Buffer<Variant>(a, n);
505   Variant* data = buf->template base<Variant>();
506   if (data == nullptr) {
507     buf->Unref();
508     return nullptr;
509   }
510   const int64 in_n = ProtoHelper<Variant>::NumElements(in);
511   if (in_n <= 0) {
512     std::fill_n(data, n, Variant());
513   } else {
514     for (int64 i = 0; i < in_n; ++i) {
515       data[i] = in.variant_val(i);
516       if (!DecodeUnaryVariant(&data[i])) {
517         LOG(ERROR) << "Could not decode variant with type_name: \""
518                    << data[i].TypeName()
519                    << "\".  Perhaps you forgot to register a "
520                       "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
521         buf->Unref();
522         return nullptr;
523       }
524     }
525     for (int64 i = in_n; i < n; ++i) {
526       data[i] = Variant();
527     }
528   }
529   return buf;
530 }
531 
532 // fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
533 // identical to uint16 but with data stored in half_val instead of int_val (ie.,
534 // we don't use ProtoHelper<uint16>).
535 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)536 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
537                                           int64 n) {
538   CHECK_GT(n, 0);
539   Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
540   uint16* data = buf->template base<uint16>();
541   if (data == nullptr) {
542     buf->Unref();
543     return nullptr;
544   }
545   const int64 in_n = in.half_val().size();
546   auto begin = in.half_val().begin();
547   if (n <= in_n) {
548     std::copy_n(begin, n, data);
549   } else if (in_n > 0) {
550     std::copy_n(begin, in_n, data);
551     const uint16 last = *(data + in_n - 1);
552     std::fill_n(data + in_n, n - in_n, last);
553   } else {
554     std::fill_n(data, n, 0);
555   }
556   return buf;
557 }
558 
559 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)560 TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
561                                        int64 n) {
562   CHECK_GT(n, 0);
563   Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
564   uint16* data = buf->template base<uint16>();
565   if (data == nullptr) {
566     buf->Unref();
567     return nullptr;
568   }
569   const int64 in_n = in.half_val().size();
570   auto begin = in.half_val().begin();
571   if (n <= in_n) {
572     std::copy_n(begin, n, data);
573   } else if (in_n > 0) {
574     std::copy_n(begin, in_n, data);
575     const uint16 last = *(data + in_n - 1);
576     std::fill_n(data + in_n, n - in_n, last);
577   } else {
578     std::fill_n(data, n, 0);
579   }
580   return buf;
581 }
582 
583 // Copies T[n] stored in the buffer "in" into the repeated field in
584 // "out" corresponding to type T.
585 template <typename T>
ToProtoField(const TensorBuffer & in,int64 n,TensorProto * out)586 void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
587   const T* data = in.base<const T>();
588   // NOTE: T may not the same as
589   // ProtoHelper<T>::FieldType::value_type.  E.g., T==int16,
590   // ProtoHelper<T>::FieldType::value_type==int32.  If performance is
591   // critical, we can specialize T=float and do memcpy directly.
592   ProtoHelper<T>::Fill(data, n, out);
593 }
594 
RefIfNonNull(core::RefCounted * buf)595 void RefIfNonNull(core::RefCounted* buf) {
596   if (buf) buf->Ref();
597 }
598 
UnrefIfNonNull(core::RefCounted * buf)599 void UnrefIfNonNull(core::RefCounted* buf) {
600   if (buf) buf->Unref();
601 }
602 
603 }  // end namespace
604 
Tensor()605 Tensor::Tensor() : Tensor(DT_FLOAT) {}
606 
Tensor(DataType type)607 Tensor::Tensor(DataType type) : shape_({0}), buf_(nullptr) { set_dtype(type); }
608 
Tensor(DataType type,const TensorShape & shape,TensorBuffer * buf)609 Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
610     : shape_(shape), buf_(buf) {
611   set_dtype(type);
612   RefIfNonNull(buf);
613 }
614 
IsInitialized() const615 bool Tensor::IsInitialized() const {
616   return (buf_ != nullptr && buf_->data() != nullptr) ||
617          shape_.num_elements() == 0;
618 }
619 
CheckType(DataType expected_dtype) const620 void Tensor::CheckType(DataType expected_dtype) const {
621   CHECK_EQ(dtype(), expected_dtype) << " "
622       << DataTypeString(expected_dtype) << " expected, got "
623       << DataTypeString(dtype());
624 }
625 
CheckTypeAndIsAligned(DataType expected_dtype) const626 void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
627   CHECK_EQ(dtype(), expected_dtype) << " "
628       << DataTypeString(expected_dtype) << " expected, got "
629       << DataTypeString(dtype());
630   CHECK(IsAligned()) << "ptr = " << base<void>();
631 }
632 
CheckIsAlignedAndSingleElement() const633 void Tensor::CheckIsAlignedAndSingleElement() const {
634   CHECK(IsAligned()) << "Aligned and single element";
635   CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
636 }
637 
~Tensor()638 Tensor::~Tensor() { UnrefIfNonNull(buf_); }
639 
CopyFromInternal(const Tensor & other,const TensorShape & shape)640 void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
641   CHECK_EQ(shape.num_elements(), other.NumElements());
642   // Data type will be overwritten if this == &other, since dtype is part of
643   // shape.
644   DataType other_dtype = other.dtype();
645   shape_ = shape;
646   set_dtype(other_dtype);
647   if (buf_ != other.buf_) {
648     UnrefIfNonNull(buf_);
649     buf_ = other.buf_;
650     RefIfNonNull(buf_);
651   }
652 }
653 
BitcastFrom(const Tensor & other,DataType dtype,const TensorShape & shape)654 Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
655                            const TensorShape& shape) {
656   int in_size = DataTypeSize(other.dtype());
657   int out_size = DataTypeSize(dtype);
658   if (in_size == 0) {
659     return errors::InvalidArgument("other tensor has zero-sized data type");
660   }
661   if (out_size == 0) {
662     return errors::InvalidArgument("specified output type is zero-sized");
663   }
664   if (shape.num_elements() * out_size !=
665       other.shape().num_elements() * in_size) {
666     return errors::InvalidArgument(
667         "input and output shapes/data type sizes are not compatible");
668   }
669   shape_ = shape;
670   shape_.set_data_type(dtype);
671   if (buf_ != other.buf_) {
672     UnrefIfNonNull(buf_);
673     buf_ = other.buf_;
674     RefIfNonNull(buf_);
675   }
676   return Status::OK();
677 }
678 
679 // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
680 // For the latter case, we have to make sure that the refcount is
681 // one both for the SubBuffer _and_ the underlying TensorBuffer.
RefCountIsOne() const682 bool Tensor::RefCountIsOne() const {
683   return buf_ != nullptr && buf_->RefCountIsOne() &&
684          buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
685 }
686 
// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.

// SINGLE_ARG forwards its arguments unchanged, letting a STMTS argument that
// contains top-level commas pass through another macro as a single argument.
#define SINGLE_ARG(...) __VA_ARGS__
// CASE(TYPE, STMTS): one switch case for DataTypeToEnum<TYPE>::value that runs
// STMTS with T aliased to TYPE, then breaks.
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
// CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT): a switch over every
// type listed below. DT_INVALID runs INVALID; any unlisted value runs DEFAULT.
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(string, SINGLE_ARG(STMTS))                            \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

// CASES(TYPE_ENUM, STMTS): CASES_WITH_DEFAULT that LOG(FATAL)s with
// "Type not set" on DT_INVALID and "Unexpected type: ..." on unlisted types.
#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
732 
Tensor(Allocator * a,DataType type,const TensorShape & shape)733 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
734     : shape_(shape), buf_(nullptr) {
735   set_dtype(type);
736   CHECK_NOTNULL(a);
737   if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
738     CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
739   }
740   if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
741     LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
742                                       *this);
743   }
744 }
745 
Tensor(Allocator * a,DataType type,const TensorShape & shape,const AllocationAttributes & allocation_attr)746 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
747                const AllocationAttributes& allocation_attr)
748     : shape_(shape), buf_(nullptr) {
749   set_dtype(type);
750   CHECK_NOTNULL(a);
751   if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
752     CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
753   }
754   if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr &&
755       buf_->data() != nullptr && LogMemory::IsEnabled()) {
756     LogMemory::RecordTensorAllocation("Unknown (with attributes)",
757                                       LogMemory::UNKNOWN_STEP_ID, *this);
758   }
759 }
760 
Tensor(DataType type,const TensorShape & shape)761 Tensor::Tensor(DataType type, const TensorShape& shape)
762     : Tensor(cpu_allocator(), type, shape) {}
763 
FillAllocationDescription(AllocationDescription * proto) const764 void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
765     AllocationDescription* proto) const {
766   proto->set_requested_bytes(size());
767   proto->set_allocator_name("HostScalarTensorBuffer");
768   proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
769 }
770 
771 template <typename T>
772 class SubBuffer : public TensorBuffer {
773  public:
774   // This buffer is an alias to buf[delta, delta + n).
SubBuffer(TensorBuffer * buf,int64 delta,int64 n)775   SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
776       : TensorBuffer(buf->base<T>() + delta),
777         root_(buf->root_buffer()),
778         elem_(n) {
779     // Sanity check. The caller should ensure the sub buffer is valid.
780     CHECK_LE(root_->base<T>(), this->base<T>());
781     T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
782     CHECK_LE(this->base<T>(), root_limit);
783     CHECK_LE(this->base<T>() + n, root_limit);
784     // Hold a ref of the underlying root buffer.
785     // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
786     root_->Ref();
787   }
788 
size() const789   size_t size() const override { return sizeof(T) * elem_; }
root_buffer()790   TensorBuffer* root_buffer() override { return root_; }
FillAllocationDescription(AllocationDescription * proto) const791   void FillAllocationDescription(AllocationDescription* proto) const override {
792     root_->FillAllocationDescription(proto);
793   }
794 
795  private:
796   TensorBuffer* root_;
797   T* data_;
798   int64 elem_;
799 
~SubBuffer()800   ~SubBuffer() override { root_->Unref(); }
801 
802   TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
803 };
804 
Slice(int64 start,int64 limit) const805 Tensor Tensor::Slice(int64 start, int64 limit) const {
806   CHECK_GE(dims(), 1);
807   CHECK_LE(0, start);
808   CHECK_LE(start, limit);
809   int64 dim0_size = shape_.dim_size(0);
810   CHECK_LE(limit, dim0_size);
811   if ((start == 0) && (limit == dim0_size)) {
812     return *this;
813   }
814   Tensor ret;
815   ret.shape_ = shape_;
816   ret.set_dtype(dtype());
817   ret.buf_ = nullptr;
818   if (dim0_size > 0) {
819     const int64 elems_per_dim0 = NumElements() / dim0_size;
820     const int64 delta = start * elems_per_dim0;
821     dim0_size = limit - start;
822     ret.shape_.set_dim(0, dim0_size);
823     const int64 num_elems = dim0_size * elems_per_dim0;
824     if (buf_) {
825       DataType dt = dtype();
826       CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
827     }
828   }
829   return ret;
830 }
831 
SubSlice(int64 index) const832 Tensor Tensor::SubSlice(int64 index) const {
833   CHECK_GE(dims(), 1);  // Crash ok.
834   CHECK_LE(0, index);   // Crash ok.
835   int64 dim0_size = shape_.dim_size(0);
836   CHECK_LE(index, dim0_size);  // Crash ok.
837   Tensor ret;
838   ret.shape_ = shape_;
839   ret.shape_.RemoveDim(0);
840   ret.set_dtype(dtype());
841   ret.buf_ = nullptr;
842   if (dim0_size > 0) {
843     const int64 elems_per_dim0 = NumElements() / dim0_size;
844     const int64 delta = index * elems_per_dim0;
845     const int64 num_elems = elems_per_dim0;
846     if (buf_) {
847       DataType dt = dtype();
848       CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
849     }
850   }
851   return ret;
852 }
853 
FromProto(const TensorProto & proto)854 bool Tensor::FromProto(const TensorProto& proto) {
855   return FromProto(cpu_allocator(), proto);
856 }
857 
FromProto(Allocator * a,const TensorProto & proto)858 bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
859   CHECK_NOTNULL(a);
860   TensorBuffer* p = nullptr;
861   if (!TensorShape::IsValid(proto.tensor_shape())) return false;
862   if (proto.dtype() == DT_INVALID) return false;
863   TensorShape shape(proto.tensor_shape());
864   const int64 N = shape.num_elements();
865   if (N > 0 && proto.dtype()) {
866     bool dtype_error = false;
867     if (!proto.tensor_content().empty()) {
868       const auto& content = proto.tensor_content();
869       CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
870                          dtype_error = true, dtype_error = true);
871     } else {
872       CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
873                          dtype_error = true, dtype_error = true);
874     }
875     if (dtype_error || p == nullptr) return false;
876   }
877   shape_ = shape;
878   set_dtype(proto.dtype());
879   UnrefIfNonNull(buf_);
880   buf_ = p;
881   // TODO(misard) add tracking of which kernels and steps are calling
882   // FromProto.
883   if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
884     LogMemory::RecordTensorAllocation("Unknown (from Proto)",
885                                       LogMemory::UNKNOWN_STEP_ID, *this);
886   }
887   return true;
888 }
889 
AsProtoField(TensorProto * proto) const890 void Tensor::AsProtoField(TensorProto* proto) const {
891   proto->Clear();
892   shape_.AsProto(proto->mutable_tensor_shape());
893   proto->set_dtype(dtype());
894   if (buf_) {
895     CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
896   }
897 }
898 
AsProtoTensorContent(TensorProto * proto) const899 void Tensor::AsProtoTensorContent(TensorProto* proto) const {
900   proto->Clear();
901   proto->set_dtype(dtype());
902   shape_.AsProto(proto->mutable_tensor_shape());
903   if (buf_) {
904     CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
905                                      proto->mutable_tensor_content()));
906   }
907 }
908 
TotalBytes() const909 size_t Tensor::TotalBytes() const {
910   if (shape_.num_elements() == 0) return 0;
911   CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
912   CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
913   return 0;  // Makes compiler happy.
914 }
915 
AllocatedBytes() const916 size_t Tensor::AllocatedBytes() const {
917   TensorDescription tensor_description;
918   FillDescription(&tensor_description);
919   if (tensor_description.has_allocation_description() &&
920       tensor_description.allocation_description().allocated_bytes() > 0) {
921     return tensor_description.allocation_description().allocated_bytes();
922   } else {
923     // Fall back to TotalBytes() if the allocator doesn't have its size.
924     return TotalBytes();
925   }
926 }
927 
CanUseDMA() const928 bool Tensor::CanUseDMA() const {
929   CASES(dtype(), return is_simple_type<T>::value);
930   return false;  // Makes compiler happy.
931 }
932 
#undef CASES
#undef CASE
// CASES_WITH_DEFAULT expands to CASE, which was just undefined, so it cannot
// be used past this point; undefine it too so no dispatch macro leaks.
#undef CASES_WITH_DEFAULT
935 
936 namespace {
937 
938 // StrCat and StrAppend don't support Eigen::half directly at the moment, and
939 // we would like to keep them compatible with their absl counterparts, for ease
940 // of migration. We could rely on errors::internal::PrepareForStrCat() but the
941 // logic is so simple we can just replicate it here, where it is close to its
942 // usage and easy to change later. And there's the extra benefit of not
943 // accessing an 'internal' namespace.
PrintOneElement(const strings::AlphaNum & a,bool print_v2)944 inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
945                                                 bool print_v2) {
946   return a;
947 }
PrintOneElement(const string & a,bool print_v2)948 inline string PrintOneElement(const string& a, bool print_v2) {
949   if (print_v2) {
950     return "\"" + str_util::CEscape(a) + "\"";
951   } else {
952     return str_util::CEscape(a);
953   }
954 }
PrintOneElement(const Eigen::half & h,bool print_v2)955 inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
956   return static_cast<float>(h);
957 }
958 
959 // Print from left dim to right dim recursively.
960 template <typename T>
PrintOneDim(int dim_index,const gtl::InlinedVector<int64,4> & shape,int64 limit,int shape_size,const T * data,int64 * data_index,string * result)961 void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
962                  int64 limit, int shape_size, const T* data, int64* data_index,
963                  string* result) {
964   if (*data_index >= limit) return;
965   int64 element_count = shape[dim_index];
966   // We have reached the right-most dimension of the tensor.
967   if (dim_index == shape_size - 1) {
968     for (int64 i = 0; i < element_count; i++) {
969       if (*data_index >= limit) {
970         // If not enough elements has been printed, append "...".
971         if (dim_index != 0 && i < element_count) {
972           strings::StrAppend(result, "...");
973         }
974         return;
975       }
976       if (i > 0) strings::StrAppend(result, " ");
977       strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
978     }
979     return;
980   }
981   // Loop every element of one dim.
982   for (int64 i = 0; i < element_count; i++) {
983     bool flag = false;
984     if (*data_index < limit) {
985       strings::StrAppend(result, "[");
986       flag = true;
987     }
988     // As for each element, print the sub-dim.
989     PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
990                 result);
991     if (*data_index < limit || flag) {
992       strings::StrAppend(result, "]");
993       flag = false;
994     }
995   }
996 }
997 
998 // Appends the spacing between elements for a given dim onto a result string
PrintDimSpacing(int dim_index,int num_dims,string * result)999 void PrintDimSpacing(int dim_index, int num_dims, string* result) {
1000   if (dim_index == num_dims - 1) {
1001     strings::StrAppend(result, " ");
1002     return;
1003   }
1004   for (int j = 0; j < num_dims - dim_index - 1; j++) {
1005     strings::StrAppend(result, "\n");
1006   }
1007   for (int j = 0; j <= dim_index; j++) {
1008     strings::StrAppend(result, " ");
1009   }
1010 }
1011 
1012 // Print from left dim to right dim recursively.
1013 template <typename T>
PrintOneDimV2(int dim_index,const gtl::InlinedVector<int64,4> & shape,int64 num_elts_at_ends,int num_dims,const T * data,int64 data_index,string * result)1014 void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
1015                    int64 num_elts_at_ends, int num_dims, const T* data,
1016                    int64 data_index, string* result) {
1017   // We have recursed beyond all the dimensions into a single element
1018   // of the tensor.
1019   if (dim_index == num_dims) {
1020     strings::StrAppend(result, PrintOneElement(data[data_index], true));
1021     return;
1022   }
1023 
1024   strings::StrAppend(result, "[");
1025   int64 element_count = shape[dim_index];
1026   int64 start_of_end =
1027       std::max(num_elts_at_ends, element_count - num_elts_at_ends);
1028 
1029   // Loop every element of one dim.
1030   int64 elements_per_iter = 1;
1031   for (int i = dim_index + 1; i < num_dims; i++) {
1032     elements_per_iter *= shape[i];
1033   }
1034   for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
1035     if (i > 0) {
1036       PrintDimSpacing(dim_index, num_dims, result);
1037     }
1038 
1039     // As for each element, print the sub-dim.
1040     PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1041                   data_index + elements_per_iter * i, result);
1042   }
1043   if (element_count > 2 * num_elts_at_ends) {
1044     PrintDimSpacing(dim_index, num_dims, result);
1045     strings::StrAppend(result, "...");
1046   }
1047   for (int64 i = start_of_end; i < element_count; i++) {
1048     // As for each element, print the sub-dim.
1049     PrintDimSpacing(dim_index, num_dims, result);
1050     PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
1051                   data_index + elements_per_iter * i, result);
1052   }
1053 
1054   strings::StrAppend(result, "]");
1055 }
1056 
1057 template <typename T>
SummarizeArray(int64 limit,int64 num_elts,const TensorShape & tensor_shape,const char * data,const bool print_v2)1058 string SummarizeArray(int64 limit, int64 num_elts,
1059                       const TensorShape& tensor_shape, const char* data,
1060                       const bool print_v2) {
1061   string ret;
1062   const T* array = reinterpret_cast<const T*>(data);
1063 
1064   const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
1065   if (shape.empty()) {
1066     for (int64 i = 0; i < limit; ++i) {
1067       if (i > 0) strings::StrAppend(&ret, " ");
1068       strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
1069     }
1070     if (num_elts > limit) strings::StrAppend(&ret, "...");
1071     return ret;
1072   }
1073   if (print_v2) {
1074     const int num_dims = tensor_shape.dims();
1075     PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
1076   } else {
1077     int64 data_index = 0;
1078     const int shape_size = tensor_shape.dims();
1079     PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
1080 
1081     if (num_elts > limit) strings::StrAppend(&ret, "...");
1082   }
1083 
1084   return ret;
1085 }
1086 }  // namespace
1087 
SummarizeValue(int64 max_entries,bool print_v2) const1088 string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
1089   const int64 num_elts = NumElements();
1090   if (max_entries < 0) {
1091     max_entries = num_elts;
1092   }
1093   size_t limit = std::min(max_entries, num_elts);
1094   if ((limit > 0) && (buf_ == nullptr)) {
1095     return strings::StrCat("uninitialized Tensor of ", num_elts,
1096                            " elements of type ", dtype());
1097   }
1098   const char* data = limit > 0 ? tensor_data().data() : nullptr;
1099   switch (dtype()) {
1100     case DT_HALF:
1101       return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
1102                                          print_v2);
1103       break;
1104     case DT_FLOAT:
1105       return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
1106       break;
1107     case DT_DOUBLE:
1108       return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
1109       break;
1110     case DT_UINT32:
1111       return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
1112       break;
1113     case DT_INT32:
1114       return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
1115       break;
1116     case DT_UINT8:
1117     case DT_QUINT8:
1118       return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
1119       break;
1120     case DT_UINT16:
1121     case DT_QUINT16:
1122       return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
1123       break;
1124     case DT_INT16:
1125     case DT_QINT16:
1126       return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
1127       break;
1128     case DT_INT8:
1129     case DT_QINT8:
1130       return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
1131       break;
1132     case DT_UINT64:
1133       return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
1134       break;
1135     case DT_INT64:
1136       return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
1137       break;
1138     case DT_BOOL:
1139       // TODO(tucker): Is it better to emit "True False..."?  This
1140       // will emit "1 0..." which is more compact.
1141       return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
1142       break;
1143     case DT_STRING:
1144       return SummarizeArray<string>(limit, num_elts, shape_, data, print_v2);
1145       break;
1146     default: {
1147       // All irregular cases
1148       string ret;
1149       if (print_v2) {
1150         strings::StrAppend(&ret, "[");
1151       }
1152       // TODO(irving): Don't call flat every time around this
1153       // loop.
1154       for (size_t i = 0; i < limit; ++i) {
1155         if (i > 0) strings::StrAppend(&ret, " ");
1156         switch (dtype()) {
1157           case DT_VARIANT: {
1158             const Variant& v = flat<Variant>()(i);
1159             strings::StrAppend(&ret, v.DebugString());
1160           } break;
1161           default:
1162             // TODO(zhifengc, josh11b): Pretty-print other types (bool,
1163             // complex64, quantized).
1164             strings::StrAppend(&ret, "?");
1165         }
1166       }
1167       if (max_entries < num_elts) strings::StrAppend(&ret, "...");
1168       if (print_v2) {
1169         strings::StrAppend(&ret, "]");
1170       }
1171       return ret;
1172     }
1173   }
1174 }
1175 
tensor_data() const1176 StringPiece Tensor::tensor_data() const {
1177   if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
1178   return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
1179 }
1180 
SharesBufferWith(const Tensor & b) const1181 bool Tensor::SharesBufferWith(const Tensor& b) const {
1182   return buf_ != nullptr && b.buf_ != nullptr &&
1183          buf_->root_buffer() == b.buf_->root_buffer();
1184 }
1185 
DebugString(int num_values) const1186 string Tensor::DebugString(int num_values) const {
1187   return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1188                          " shape: ", shape().DebugString(),
1189                          " values: ", SummarizeValue(num_values), ">");
1190 }
1191 
DeviceSafeDebugString() const1192 string Tensor::DeviceSafeDebugString() const {
1193   return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1194                          " shape: ", shape().DebugString(), ">");
1195 }
1196 
FillDescription(TensorDescription * description) const1197 void Tensor::FillDescription(TensorDescription* description) const {
1198   description->set_dtype(dtype());
1199   shape().AsProto(description->mutable_shape());
1200   if (buf_ != nullptr && buf_->data() != nullptr) {
1201     buf_->FillAllocationDescription(
1202         description->mutable_allocation_description());
1203   }
1204 }
1205 
ComputeFlatInnerDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1206 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
1207     gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1208   gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1209   int64 offset = orig.size() - num_out_dims;
1210   for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
1211     const int64 in_dim = out_dim + offset;
1212     out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
1213   }
1214   for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
1215     out_dims[0] *= orig[in_dim];
1216   }
1217   return out_dims;
1218 }
1219 
ComputeFlatOuterDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1220 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
1221     gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1222   gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1223   for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
1224     out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
1225   }
1226   for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
1227     out_dims[num_out_dims - 1] *= orig[in_dim];
1228   }
1229   return out_dims;
1230 }
1231 
1232 }  // namespace tensorflow
1233