/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_

#include <cstdint>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Forward declarations. In particular, we forward declare protos so that their
// symbols can be removed from .so exports.
class AllocationDescription;
class Allocator;
class OpKernelContext;
class Tensor;
class TensorBuffer;
class TensorCApi;
class TensorCord;
class TensorDescription;
class TensorProto;
class Var;

namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index);
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
}  // namespace batch_util

/// @ingroup core

/// Interface to access the raw ref-counted data buffer.
class TensorBuffer : public core::RefCounted {
 public:
  explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {}
  ~TensorBuffer() override {}

  /// \brief data() points to a memory region of size() bytes.
  ///
  /// NOTE(mrry): The `data()` method is not virtual for performance reasons.
  /// It can be called multiple times when the contents of a `Tensor` are
  /// accessed, and so making it non-virtual allows the body to be inlined.
  void* data() const { return data_; }

  /// \brief Size (in bytes) of the buffer.
  virtual size_t size() const = 0;

  /// \brief If this TensorBuffer is sub-buffer of another TensorBuffer,
  /// returns that TensorBuffer. Otherwise, returns this.
  virtual TensorBuffer* root_buffer() = 0;

  /// \brief Fills metadata about the allocation into the proto.
  virtual void FillAllocationDescription(
      AllocationDescription* proto) const = 0;

  virtual bool GetAllocatedBytes(size_t* out_bytes) const;

  /// \brief Helper method to reinterpret the buffer as an array of `T`.
  template <typename T>
  T* base() const {
    return reinterpret_cast<T*>(data());
  }

  /// \brief Whether this TensorBuffer owns the underlying memory.
  virtual bool OwnsMemory() const { return true; }

 private:
  void* const data_;
};

/// Represents an n-dimensional array of values.
class Tensor {
 public:
  /// \brief Creates a 1-dimensional, 0-element float tensor.
  ///
  /// The returned Tensor is not a scalar (shape {}), but is instead
  /// an empty one-dimensional Tensor (shape {0}, NumElements() ==
  /// 0). Since it has no elements, it does not need to be assigned a
  /// value and is initialized by default (IsInitialized() is
  /// true). If this is undesirable, consider creating a one-element
  /// scalar which does require initialization:
  ///
  /// ```c++
  ///
  ///     Tensor(DT_FLOAT, TensorShape({}))
  ///
  /// ```
  Tensor();

  /// \brief Creates a Tensor of the given `type` and `shape`. If
  /// LogMemory::IsEnabled() the allocation is logged as coming from
  /// an unknown kernel and step. Calling the Tensor constructor
  /// directly from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// The underlying buffer is allocated using a `CPUAllocator`.
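  ///
  /// A minimal usage sketch (illustrative only):
  ///
  /// ```c++
  ///
  ///     // Allocate a 2 x 3 float tensor on the CPU and zero it.
  ///     Tensor t(DT_FLOAT, TensorShape({2, 3}));
  ///     t.flat<float>().setZero();
  ///
  /// ```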
  Tensor(DataType type, const TensorShape& shape);

  /// \brief Creates a tensor with the input `type` and `shape`, using
  /// the allocator `a` to allocate the underlying buffer. If
  /// LogMemory::IsEnabled() the allocation is logged as coming from
  /// an unknown kernel and step. Calling the Tensor constructor
  /// directly from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// `a` must outlive the lifetime of this Tensor.
  Tensor(Allocator* a, DataType type, const TensorShape& shape);

  /// \brief Creates a tensor with the input `type` and `shape`, using
  /// the allocator `a` and the specified "allocation_attr" to
  /// allocate the underlying buffer. If the kernel and step are known
  /// allocation_attr.allocation_will_be_logged should be set to true
  /// and LogMemory::RecordTensorAllocation should be called after the
  /// tensor is constructed. Calling the Tensor constructor directly
  /// from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// `a` must outlive the lifetime of this Tensor.
  Tensor(Allocator* a, DataType type, const TensorShape& shape,
         const AllocationAttributes& allocation_attr);

  /// \brief Creates a tensor with the input datatype, shape and buf.
  ///
  /// Acquires a ref on buf that belongs to this Tensor.
  Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);

  /// \brief Creates an empty Tensor of the given data type.
  ///
  /// Like Tensor(), returns a 1-dimensional, 0-element Tensor with
  /// IsInitialized() returning True. See the Tensor() documentation
  /// for details.
  explicit Tensor(DataType type);

 private:
  // A tag type for selecting the `Tensor` constructor overload that creates a
  // scalar tensor in host memory.
  struct host_scalar_tag {};

  class HostScalarTensorBufferBase;
  template <typename T>
  struct ValueAndTensorBuffer;

  // Creates a tensor with the given scalar `value` in CPU memory.
  template <typename T>
  Tensor(T value, host_scalar_tag tag);

 public:
  // A series of specialized constructors for scalar tensors in host memory.
  //
  // NOTE: The `Variant` host-scalar constructor is not defined, because Variant
  // is implicitly constructible from many different types, and this causes
  // ambiguities with some compilers.
  explicit Tensor(float scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(double scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int32 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint32 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint8 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int8 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(tstring scalar_value)
      : Tensor(std::move(scalar_value), host_scalar_tag{}) {}
  explicit Tensor(complex64 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(complex128 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int64 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint64 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(bool scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(qint8 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(quint8 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(qint16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(quint16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(qint32 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(bfloat16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(Eigen::half scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(ResourceHandle scalar_value)
      : Tensor(std::move(scalar_value), host_scalar_tag{}) {}

  // NOTE: The `const char*` host-scalar constructor is provided as a
  // convenience because otherwise passing a string literal would surprisingly
  // construct a DT_BOOL tensor.
  explicit Tensor(const char* scalar_value)
      : Tensor(tstring(scalar_value), host_scalar_tag{}) {}
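
  // Illustrative sketch of the scalar constructors above:
  //
  //   Tensor a(3.0f);     // DT_FLOAT scalar holding 3.0f
  //   Tensor b(true);     // DT_BOOL scalar
  //   Tensor c("hello");  // DT_STRING scalar, via the const char* overload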

  /// Copy constructor.
  Tensor(const Tensor& other);

  /// \brief Move constructor. After this call, <other> is safely destructible
  /// and can be assigned to, but other calls on it (e.g. shape manipulation)
  /// are not valid.
  Tensor(Tensor&& other);

  ~Tensor();

  /// Returns the data type.
  DataType dtype() const { return shape_.data_type(); }

  /// Returns the shape of the tensor.
  const TensorShape& shape() const { return shape_; }

  /// \brief Convenience accessor for the tensor shape.
  ///
  /// For all shape accessors, see comments for relevant methods of
  /// `TensorShape` in `tensor_shape.h`.
  int dims() const { return shape().dims(); }

  /// Convenience accessor for the tensor shape.
  int64 dim_size(int d) const { return shape().dim_size(d); }

  /// Convenience accessor for the tensor shape.
  int64 NumElements() const { return shape().num_elements(); }

  bool IsSameSize(const Tensor& b) const {
    return shape().IsSameSize(b.shape());
  }

  // True iff the two tensors use the same underlying refcounted storage
  bool SharesBufferWith(const Tensor& b) const;

  /// \brief If necessary, has this Tensor been initialized?
  ///
  /// Zero-element Tensors are always considered initialized, even if they
  /// have never been assigned to and do not have any memory allocated.
  bool IsInitialized() const;

  /// Returns the estimated memory usage of this tensor.
  size_t TotalBytes() const;

  // Returns the size of allocated memory for this tensor.
  size_t AllocatedBytes() const;

  /// Returns true iff this tensor is aligned.
  bool IsAligned() const {
#if EIGEN_MAX_ALIGN_BYTES == 0
    return true;
#else
    void* ptr = base<void>();
    return dtype() == DT_STRING ||
           (reinterpret_cast<intptr_t>(ptr) % EIGEN_MAX_ALIGN_BYTES == 0);
#endif
  }

  /// Assign operator. This tensor shares other's underlying storage.
  Tensor& operator=(const Tensor& other) {
    CopyFromInternal(other, other.shape());
    return *this;
  }

  /// Move operator. See move constructor for details.
  Tensor& operator=(Tensor&& other);

  /// \brief Copy the other tensor into this tensor and reshape it.
  ///
  /// This tensor shares other's underlying storage. Returns `true`
  /// iff `other.shape()` has the same number of elements as the given
  /// `shape`.
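  ///
  /// A small sketch of the expected behavior (illustrative only):
  ///
  /// ```c++
  ///
  ///     Tensor src(DT_FLOAT, TensorShape({2, 6}));
  ///     Tensor dst;
  ///     bool ok = dst.CopyFrom(src, TensorShape({3, 4}));  // 12 elements -> true
  ///     bool bad = dst.CopyFrom(src, TensorShape({5}));    // 5 != 12     -> false
  ///
  /// ```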
  bool CopyFrom(const Tensor& other,
                const TensorShape& shape) TF_MUST_USE_RESULT {
    if (other.NumElements() != shape.num_elements()) return false;
    CopyFromInternal(other, shape);
    return true;
  }

  /// \brief Slice this tensor along the 1st dimension.

  /// I.e., the returned tensor satisfies
  ///     returned[i, ...] == this[dim0_start + i, ...].
  /// The returned tensor shares the underlying tensor buffer with this
  /// tensor.
  ///
  /// NOTE: The returned tensor may not satisfy the same alignment
  /// requirement as this tensor depending on the shape. The caller
  /// must check the returned tensor's alignment before calling certain
  /// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
  ///
  /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor
  /// also with N dimensions. If you want to select a sub tensor, see SubSlice.
  ///
  /// REQUIRES: `dims()` >= 1
  /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
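  ///
  /// For example (illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor m(DT_FLOAT, TensorShape({10, 4}));
  ///     Tensor rows = m.Slice(2, 5);  // shape {3, 4}; shares m's buffer
  ///
  /// ```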
  Tensor Slice(int64 dim0_start, int64 dim0_limit) const;

  /// \brief Select a subslice from this tensor along the 1st dimension.
  ///
  /// When fed with an N-dimensional tensor, this method returns a tensor with
  /// N-1 dimensions, where the returned tensor is a subslice of the input
  /// tensor along the first dimension. The N-1 dimensions of the returned
  /// tensor are the last N-1 dimensions of the input tensor.
  ///
  /// NOTE: The returned tensor may not satisfy the same alignment
  /// requirement as this tensor depending on the shape. The caller
  /// must check the returned tensor's alignment before calling certain
  /// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
  ///
  /// REQUIRES: `dims()` >= 1
  /// REQUIRES: `0 <= index < dim_size(0)`
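  ///
  /// For example (illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor m(DT_FLOAT, TensorShape({10, 4}));
  ///     Tensor row = m.SubSlice(2);  // shape {4}; shares m's buffer
  ///
  /// ```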
  Tensor SubSlice(int64 index) const;

  /// \brief Parse `other` and construct the tensor.

  /// Returns `true` iff the parsing succeeds. If the parsing fails,
  /// the state of `*this` is unchanged.
  bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
  bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;

  /// \brief Fills in `proto` with `*this` tensor's content.
  ///
  /// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while
  /// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()`
  /// in a compact form.
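  ///
  /// A round-trip sketch (illustrative only):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_INT32, TensorShape({2}));
  ///     t.flat<int32>().setConstant(7);
  ///     TensorProto proto;
  ///     t.AsProtoTensorContent(&proto);
  ///     Tensor restored;
  ///     bool parsed = restored.FromProto(proto);  // true on success
  ///
  /// ```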
  void AsProtoField(TensorProto* proto) const;
  void AsProtoTensorContent(TensorProto* proto) const;

  /// \brief Return the tensor data as an `Eigen::Tensor` with the type and
  /// sizes of this `Tensor`.
  ///
  /// Use these methods when you know the data type and the number of
  /// dimensions of the Tensor and you want an `Eigen::Tensor`
  /// automatically sized to the `Tensor` sizes. The implementation check
  /// fails if either type or sizes mismatch.
  ///
  /// Example:
  ///
  /// ```c++
  ///
  ///     typedef float T;
  ///     Tensor my_mat(...built with Shape{rows: 3, cols: 5}...);
  ///     auto mat = my_mat.matrix<T>();    // 2D Eigen::Tensor, 3 x 5.
  ///     auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5.
  ///     auto vec = my_mat.vec<T>();       // CHECK fails as my_mat is 2D.
  ///     auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D.
  ///     auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
  ///
  /// ```
  template <typename T>
  typename TTypes<T>::Vec vec() {
    return tensor<T, 1>();
  }

  template <typename T>
  typename TTypes<T>::Matrix matrix() {
    return tensor<T, 2>();
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor tensor();

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// same size but a bitwise cast to the specified dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// NOTE: this is the same as `tensor()` except a bitcast is allowed.
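  ///
  /// A sketch, assuming `T` has the same size as the stored element type:
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({2, 2}));
  ///     auto bits = t.bit_casted_tensor<uint32, 2>();  // raw float bits
  ///
  /// ```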
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor bit_casted_tensor();

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// last dimension elements converted into single elements of a larger type.
  ///
  /// For example, this is useful for kernels that can treat NCHW_VECT_C int8
  /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of
  /// the original element type * num elements in the original last dimension.
  /// NDIMS should be 1 less than the original number of dimensions.
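  ///
  /// For example (illustrative sketch of the NCHW_VECT_C case):
  ///
  /// ```c++
  ///
  ///     // int8 tensor of shape {N, C/4, H, W, 4} viewed as int32 {N, C/4, H, W}.
  ///     Tensor t(DT_INT8, TensorShape({1, 8, 4, 4, 4}));
  ///     auto as_int32 = t.reinterpret_last_dimension<int32, 4>();
  ///
  /// ```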
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor reinterpret_last_dimension();

  /// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a
  /// specified shape.
  ///
  /// These methods allow you to access the data with the dimensions
  /// and sizes of your choice. You do not need to know the number of
  /// dimensions of the Tensor to call them. However, they `CHECK` that
  /// the type matches and that the requested dimensions create an
  /// `Eigen::Tensor` with the same number of elements as the tensor.
  ///
  /// Example:
  ///
  /// ```c++
  ///
  ///     typedef float T;
  ///     Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...);
  ///     // 1D Eigen::Tensor, size 60:
  ///     auto flat = my_ten.flat<T>();
  ///     // 2D Eigen::Tensor 12 x 5:
  ///     auto inner = my_ten.flat_inner_dims<T>();
  ///     // 2D Eigen::Tensor 4 x 15:
  ///     auto outer = my_ten.shaped<T, 2>({4, 15});
  ///     // CHECK fails, bad num elements:
  ///     auto outer = my_ten.shaped<T, 2>({4, 8});
  ///     // 3D Eigen::Tensor 6 x 5 x 2:
  ///     auto weird = my_ten.shaped<T, 3>({6, 5, 2});
  ///     // CHECK fails, type mismatch:
  ///     auto bad = my_ten.flat<int32>();
  ///
  /// ```
  template <typename T>
  typename TTypes<T>::Flat flat() {
    return shaped<T, 1>({NumElements()});
  }

  template <typename T>
  typename TTypes<T>::UnalignedFlat unaligned_flat() {
    return unaligned_shaped<T, 1>({NumElements()});
  }

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
  /// Tensor dimensions but the last NDIMS-1 into the first dimension of the
  /// result. If NDIMS > dims() then leading dimensions of size 1 will be
  /// added to make the output rank NDIMS.
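  ///
  /// For example (illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({2, 3, 4}));
  ///     auto inner = t.flat_inner_dims<float>();    // 2D, 6 x 4
  ///     auto same = t.flat_inner_dims<float, 3>();  // 3D, 2 x 3 x 4
  ///
  /// ```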
  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::Tensor flat_inner_dims();

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
  /// Tensor dimensions but the first NDIMS-1 into the last dimension of the
  /// result. If NDIMS > dims() then trailing dimensions of size 1 will be
  /// added to make the output rank NDIMS.
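  ///
  /// For example (illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({2, 3, 4}));
  ///     auto outer = t.flat_outer_dims<float>();  // 2D, 2 x 12
  ///
  /// ```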
  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::Tensor flat_outer_dims();

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the
  /// first 'begin' Tensor dimensions into the first dimension of the result and
  /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last
  /// dimension of the result. If 'begin' < 0 then the |'begin'| leading
  /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then
  /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added.
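  ///
  /// For example (illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({2, 3, 4, 5, 6}));
  ///     // Collapses {2, 3} into the first and {5, 6} into the last dimension:
  ///     auto mid = t.flat_inner_outer_dims<float, 3>(1);  // 3D, 6 x 4 x 30
  ///
  /// ```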
  template <typename T, size_t NDIMS = 3>
  typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin);

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);

  /// \brief Return the tensor data to an `Eigen::Tensor` with the new
  /// shape specified in `new_sizes` and cast to a new dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// The allowed bitcast is the only difference from `shaped()`.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor bit_casted_shaped(
      gtl::ArraySlice<int64> new_sizes);

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
      gtl::ArraySlice<int64> new_sizes);

  /// \brief Return the Tensor data as a `TensorMap` of fixed size 1:
  /// `TensorMap<TensorFixedSize<T, 1>>`.

  /// Using `scalar()` allows the compiler to perform optimizations as
  /// the size of the tensor is known at compile time.
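  ///
  /// For example (illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({}));
  ///     t.scalar<float>()() = 42.0f;
  ///     float v = t.scalar<float>()();  // 42.0f
  ///
  /// ```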
  template <typename T>
  typename TTypes<T>::Scalar scalar();

  /// Const versions of all the methods above.
  template <typename T>
  typename TTypes<T>::ConstVec vec() const {
    return tensor<T, 1>();
  }

  template <typename T>
  typename TTypes<T>::ConstMatrix matrix() const {
    return tensor<T, 2>();
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor tensor() const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// same size but a bitwise cast to the specified dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// NOTE: this is the same as `tensor()` except a bitcast is allowed.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor bit_casted_tensor() const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// last dimension elements converted into single elements of a larger type.
  ///
  /// For example, this is useful for kernels that can treat NCHW_VECT_C int8
  /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of
  /// the original element type * num elements in the original last dimension.
  /// NDIMS should be 1 less than the original number of dimensions.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor reinterpret_last_dimension() const;

  template <typename T>
  typename TTypes<T>::ConstFlat flat() const {
    return shaped<T, 1>({NumElements()});
  }

  template <typename T>
  typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
    return unaligned_shaped<T, 1>({NumElements()});
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor shaped(
      gtl::ArraySlice<int64> new_sizes) const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the new
  /// shape specified in `new_sizes` and cast to a new dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// The allowed bitcast is the only difference from `shaped()`.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor bit_casted_shaped(
      gtl::ArraySlice<int64> new_sizes) const;

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
      gtl::ArraySlice<int64> new_sizes) const;

  template <typename T>
  typename TTypes<T>::ConstScalar scalar() const;

  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const;

  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;

  template <typename T, size_t NDIMS = 3>
  typename TTypes<T, NDIMS>::ConstTensor flat_inner_outer_dims(
      int64 begin) const;

  /// Render the first `max_entries` values in `*this` into a string.
  string SummarizeValue(int64 max_entries, bool print_v2 = false) const;

  /// A human-readable summary of the tensor suitable for debugging.
  // `num_values` is the number of actual data values in the tensor
  // included in the message. If the tensor might be resident in
  // GPU/TPU memory use DeviceSafeDebugString instead.
  string DebugString(int num_values) const;
  string DebugString() const { return DebugString(3); }

  // Variant of DebugString() that should be used for possibly non-CPU tensors.
  // If the tensor is not resident on CPU, we can't read its values as
  // DebugString() does.
  string DeviceSafeDebugString() const;

  /// Fill in the `TensorDescription` proto with metadata about the
  /// tensor that is useful for monitoring and debugging.
  void FillDescription(TensorDescription* description) const;

  /// \brief Returns a `StringPiece` mapping the current tensor's buffer.
  ///
  /// The returned `StringPiece` may point to memory location on devices
  /// that the CPU cannot address directly.
  ///
  /// NOTE: The underlying tensor buffer is refcounted, so the lifetime
  /// of the contents mapped by the `StringPiece` matches the lifetime of
  /// the buffer; callers should arrange to make sure the buffer does
  /// not get destroyed while the `StringPiece` is still used.
  ///
  /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
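  ///
  /// A small sketch (illustrative only):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({4}));
  ///     StringPiece bytes = t.tensor_data();  // bytes.size() == 4 * sizeof(float)
  ///
  /// ```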
  StringPiece tensor_data() const;

  /// Copy the other tensor into this tensor, reshape it and reinterpret the
  /// buffer's datatype. If Status::OK() is returned, the two tensors now share
  /// the same underlying storage.
  ///
  /// This call requires that the `other` tensor and the given type and shape
  /// are "compatible" (i.e. they occupy the same number of bytes).
  ///
  /// Specifically:
  ///
  ///     shape.num_elements() * DataTypeSize(type)
  ///
  /// must equal
  ///
  ///     other.num_elements() * DataTypeSize(other.dtype())
  ///
  /// In addition, this function requires:
  ///   * DataTypeSize(other.dtype()) != 0
  ///   * DataTypeSize(type) != 0
  ///
  /// If any of the requirements are not met, errors::InvalidArgument is
  /// returned.
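  ///
  /// A sketch of a byte-compatible bitcast (illustrative only):
  ///
  /// ```c++
  ///
  ///     Tensor src(DT_FLOAT, TensorShape({2, 2}));
  ///     Tensor dst;
  ///     // 4 elements * 4 bytes on both sides, so this succeeds.
  ///     Status s = dst.BitcastFrom(src, DT_INT32, TensorShape({4}));
  ///
  /// ```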
  Status BitcastFrom(const Tensor& other, DataType dtype,
                     const TensorShape& shape);

  /// Like BitcastFrom, but CHECK fails if any preconditions are not met.
  ///
  /// Deprecated. Use BitcastFrom instead and check the returned Status.
  void UnsafeCopyFromInternal(const Tensor& other, DataType dtype,
                              const TensorShape& shape) {
    TF_CHECK_OK(BitcastFrom(other, dtype, shape));
  }

  // Returns true if the refcount on buf_ and any possible underlying root
  // buffer is one.
  bool RefCountIsOne() const;

 private:
  void CheckType(DataType expected_dtype) const;
  void CheckTypeAndIsAligned(DataType expected_dtype) const;
  void CheckIsAlignedAndSingleElement() const;
  void set_dtype(DataType t) { shape_.set_data_type(t); }

  // TensorShape's InlineVector.
  static gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(
      gtl::ArraySlice<int64> orig, int64 num_out_dims);
  static gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(
      gtl::ArraySlice<int64> orig, int64 num_out_dims);

  TensorShape shape_;
  TensorBuffer* buf_;

  friend class DMAHelper;             // For access to buf_.
  friend class TensorCApi;            // For access to buf_.
  friend class TensorCord;            // For access to buf_.
  friend class TensorReference;       // For access to buf_.
  friend class VariableOp;            // For access to set_shape.
  friend class AutoReloadVariableOp;  // For access to set_shape.
  friend class TensorTestHelper;      // For access to set_shape.
  friend class CastOpBase;            // For access to set_dtype.
  friend class ScopedAllocator;       // For access to buf_.
  friend Status batch_util::CopyElementToSlice(
      Tensor element, Tensor* parent,
      int64 index);  // For access to base<T>().
  friend Status batch_util::CopySliceToElement(
      const Tensor& parent, Tensor* element,
      int64 index);  // For access to base<T>().
  friend Status batch_util::MaybeMoveSliceToElement(
      Tensor* parent, Tensor* element,
      int64 index);  // For access to base<T>().

  bool CanUseDMA() const;

  // Only needed by variable op to set the shape of an uninitialized
  // Tensor.
  // TODO: Remove this when we have a better story for detecting
  // uninitialized tensors.
  void set_shape(const TensorShape& shape) {
    DataType dt = dtype();
    shape_ = shape;
    set_dtype(dt);
  }

  void CopyFromInternal(const Tensor& other, const TensorShape& shape);

  template <typename T>
  T* base() const;

  template <size_t NDIMS>
  void FillDimsAndValidateCompatibleShape(
      gtl::ArraySlice<int64> new_sizes,
      Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;

  template <typename T, size_t NDIMS>
  void FillDimsAndValidateCompatibleShape(
      gtl::ArraySlice<int64> new_sizes,
      Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
};

// Implementation details

// START_SKIP_DOXYGEN

template <typename T>
T* Tensor::base() const {
  return buf_ == nullptr ? nullptr : buf_->base<T>();
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  return typename TTypes<T, NDIMS>::Tensor(base<T>(),
                                           shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
                                                shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_tensor() {
  CHECK(IsAligned());
  return typename TTypes<T, NDIMS>::Tensor(base<T>(),
                                           shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_tensor() const {
  CHECK(IsAligned());
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
                                                shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::reinterpret_last_dimension() {
  if (NDIMS == dims()) {
    return tensor<T, NDIMS>();
  }
  CHECK(IsAligned());
  CHECK_EQ(NDIMS, dims() - 1);
  CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype()));
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  for (int d = 0; d < NDIMS; ++d) {
    dims[d] = shape_.dim_sizes()[d];
  }
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::reinterpret_last_dimension()
    const {
  if (NDIMS == dims()) {
    return tensor<T, NDIMS>();
  }
  CHECK(IsAligned());
  CHECK_EQ(NDIMS, dims() - 1);
  CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype()));
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  for (int d = 0; d < NDIMS; ++d) {
    dims[d] = shape_.dim_sizes()[d];
  }
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(), dims);
}

template <size_t NDIMS>
void Tensor::FillDimsAndValidateCompatibleShape(
    gtl::ArraySlice<int64> new_sizes,
    Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
  CHECK_EQ(NDIMS, new_sizes.size());
  int64 new_num_elements = 1;
  for (size_t d = 0; d < NDIMS; d++) {
    new_num_elements *= new_sizes[d];
    (*dims)[d] = new_sizes[d];
  }
  CHECK_EQ(new_num_elements, NumElements());
}

template <typename T, size_t NDIMS>
void Tensor::FillDimsAndValidateCompatibleShape(
    gtl::ArraySlice<int64> new_sizes,
    Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
  CHECK_EQ(NDIMS, new_sizes.size());
  int64 new_num_elements = 1;
  for (size_t d = 0; d < NDIMS; d++) {
    new_num_elements *= new_sizes[d];
    (*dims)[d] = new_sizes[d];
  }
  const int element_size = DataTypeSize(BaseType(dtype()));
  if (element_size > 0) {
    CHECK_EQ(new_num_elements * sizeof(T), NumElements() * element_size);
  } else {
    // DataTypeSize() returns 0 for some data types. In this case, assume that T
    // has the same size as the buffer type.
    // NOTE: If we can be sure that DataTypeSize() does not return 0 for all POD
    // types, then we should check DataTypeToEnum<T>::v() == dtype(). Or simply
    // check if `element_size > 0` to err when bit cast is attempted on Tensor
    // of unknown data type size.
    CHECK_EQ(new_num_elements, NumElements());
  }
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
    gtl::ArraySlice<int64> new_sizes) {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_shaped(
    gtl::ArraySlice<int64> new_sizes) {
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
    gtl::ArraySlice<int64> new_sizes) {
  CheckType(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
    gtl::ArraySlice<int64> new_sizes) const {
  CheckType(DataTypeToEnum<T>::v());
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_shaped(
    gtl::ArraySlice<int64> new_sizes) const {
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
    gtl::ArraySlice<int64> new_sizes) const {
  CheckType(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
}

template <typename T>
typename TTypes<T>::Scalar Tensor::scalar() {
  static_assert(
      !std::is_same<T, std::string>::value,
      "std::string is no longer a scalar type, use tensorflow::tstring");
  CheckIsAlignedAndSingleElement();
  return typename TTypes<T>::Scalar(base<T>());
}

template <typename T>
typename TTypes<T>::ConstScalar Tensor::scalar() const {
  static_assert(
      !std::is_same<T, std::string>::value,
      "std::string is no longer a scalar type, use tensorflow::tstring");
  CheckIsAlignedAndSingleElement();
  return typename TTypes<T>::ConstScalar(base<T>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() {
  return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() {
  return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64 begin) {
  gtl::InlinedVector<int64, 4> flat_outer =
      ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS);
  return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const {
  return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const {
  return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_outer_dims(
    int64 begin) const {
  gtl::InlinedVector<int64, 4> flat_outer =
      ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS);
  return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}

inline Tensor::Tensor(const Tensor& other)
    : shape_(other.shape()), buf_(other.buf_) {
  if (buf_) buf_->Ref();
}

inline Tensor::Tensor(Tensor&& other)
    : shape_(std::move(other.shape_)), buf_(other.buf_) {
  other.buf_ = nullptr;
}

class Tensor::HostScalarTensorBufferBase : public TensorBuffer {
 public:
  using TensorBuffer::TensorBuffer;
  bool GetAllocatedBytes(size_t* out_bytes) const final;
  void FillAllocationDescription(AllocationDescription* proto) const final;
};

// A packed representation for a single scalar value of type `T`, and a
// `TensorBuffer` implementation that describes (and manages the lifetime of)
// that value.
template <typename T>
struct Tensor::ValueAndTensorBuffer {
  class HostScalarTensorBuffer : public Tensor::HostScalarTensorBufferBase {
   public:
    explicit HostScalarTensorBuffer(void* data)
        : HostScalarTensorBufferBase(data) {}
    size_t size() const final { return sizeof(T); }
    TensorBuffer* root_buffer() final { return this; }

    // Override `operator delete` so that calling `delete this` in
    // `core::Refcounted::Unref()` for an object of this type will free
    // the enclosing `ValueAndTensorBuffer` for the tensor buffer.
    //
    // NOTE(mrry): The definition of this method must be outside the class
    // definition in order to satisfy some compilers.
    static void operator delete(void* ptr);

    static void operator delete(void*, void*) {
      // Some compilers require an overridden class-specific deallocation
      // function, which will be called if placement `new` throws an
      // exception.
    }

   private:
    ~HostScalarTensorBuffer() override { static_cast<T*>(data())->~T(); }
  };

  T value;
  HostScalarTensorBuffer tensor_buffer;
};

/* static */
template <typename T>
void Tensor::ValueAndTensorBuffer<T>::HostScalarTensorBuffer::operator delete(
    void* ptr) {
  // Use a dummy object to compute the offset of
  // `ValueAndTensorBuffer::tensor_buffer`, because `offsetof()` is not
  // necessarily defined on this non-POD type (until C++17).
  //
  // NOTE(mrry): Using `sizeof(Tensor::ValueAndTensorBuffer<T>)` here requires
  // us to define this method outside the class definition, so that it is not
  // considered an incomplete type.
  typename std::aligned_storage<sizeof(Tensor::ValueAndTensorBuffer<T>),
                                alignof(Tensor::ValueAndTensorBuffer<T>)>::type
      dummy_storage_;
  Tensor::ValueAndTensorBuffer<T>* dummy_object =
      reinterpret_cast<Tensor::ValueAndTensorBuffer<T>*>(&dummy_storage_);
  intptr_t offset = reinterpret_cast<intptr_t>(&dummy_object->tensor_buffer) -
                    reinterpret_cast<intptr_t>(dummy_object);

  port::AlignedFree(static_cast<char*>(ptr) - offset);
}

template <typename T>
Tensor::Tensor(T value, host_scalar_tag tag) {
  auto* value_and_buf = static_cast<Tensor::ValueAndTensorBuffer<T>*>(
      port::AlignedMalloc(sizeof(typename Tensor::ValueAndTensorBuffer<T>),
                          EIGEN_MAX_ALIGN_BYTES));
  new (&value_and_buf->value) T(std::move(value));
  new (&value_and_buf->tensor_buffer)
      typename Tensor::ValueAndTensorBuffer<T>::HostScalarTensorBuffer(
          value_and_buf);
  buf_ = &value_and_buf->tensor_buffer;
  set_dtype(DataTypeToEnum<T>::value);
}

inline Tensor& Tensor::operator=(Tensor&& other) {
  // Avoid self-assignment, since we might destroy our underlying buffer.
  if (&other != this) {
    shape_ = std::move(other.shape_);
    if (buf_) buf_->Unref();
    buf_ = other.buf_;
    other.buf_ = nullptr;
  }
  return *this;
}

// END_SKIP_DOXYGEN

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_