/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_

#include <limits>
#include <ostream>
#include <string>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape.
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape. After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  std::string DebugString() const;
  static std::string DebugString(const TensorShapeProto& proto);

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape. Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors have 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one
  // less.
  static constexpr int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static constexpr int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();
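
  // Illustrative sketch (not part of the API): which representation a shape
  // ends up in depends on its rank and dimension sizes.
  //
  //   TensorShape({2, 3, 4});              // 3 dims, all < 2^16 - 1 -> Rep16
  //   TensorShape({1, 1 << 20});           // a dim >= 2^16 - 1      -> Rep32
  //   TensorShape({1, 2, 3, 4, 1 << 20});  // 5 dims, one large dim  -> Rep64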

  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage. This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..13] vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static constexpr uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }
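
  // Quick reference (a sketch; the accessors above and below are
  // authoritative): the leading bytes of the 16-byte buffer hold the inline
  // dims or the out-of-line pointer, byte 13 holds the DataType, byte 14 the
  // number of dimensions (255 = unknown rank), and byte 15 the RepTag.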

  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64 num_elements_;
};

/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
template <class Shape>
class TensorShapeBase : public TensorShapeRep {
 public:
  /// \brief Construct a `TensorShapeBase` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
  explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
  TensorShapeBase(std::initializer_list<int64> dim_sizes)
      : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}

  /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
  TensorShapeBase();

  // TODO(mihaimaruseac): Mark this explicit in a subsequent change
  TensorShapeBase(const TensorShapeProto& proto);

  // These factory methods should be used instead of the constructors that take
  // an array of sizes if calling code cannot validate that the sizes specify a
  // valid `TensorShape`.
  // The value in `*out` is valid iff the returned value is `Status::OK`.
  static Status BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,
                                     TensorShapeBase* out);
  static Status BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,
                                     TensorShapeBase* out) {
    return BuildTensorShapeBase(gtl::ArraySlice<int64>(dim_sizes), out);
  }
  static Status BuildTensorShapeBase(const TensorShapeProto& proto,
                                     TensorShapeBase* out);
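
  // A minimal usage sketch (illustrative; `dim0`/`dim1` are placeholders for
  // sizes coming from untrusted input, where the factory is preferred over
  // the CHECK-failing constructor):
  //
  //   TensorShape shape;
  //   Status s = TensorShape::BuildTensorShapeBase({dim0, dim1}, &shape);
  //   if (!s.ok()) { /* reject the input instead of crashing */ }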

  /// Returns `true` iff `proto` is a valid tensor shape.
  // For TensorShape, the proto shape must be fully defined.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// Returns `true` iff this is a valid tensor shape.
  bool IsValid();

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64 size);

  /// Same as `AddDim` but returns a `Status`.
  /// Use if unsure whether `size >= 0`, to prevent `CHECK`-crashes.
  Status AddDimWithStatus(int64 size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeBase& shape);

  /// Same as `AppendShape` but returns a `Status`.
  /// Use if you cannot validate all invariants, to prevent `CHECK`-fail.
  Status AppendShapeWithStatus(const TensorShapeBase& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64 size);

  /// Same as `InsertDim` but returns a `Status`.
  /// Use if unsure whether the requirements in `InsertDim` are satisfied, to
  /// prevent `CHECK`-fail crashes.
  Status InsertDimWithStatus(int d, int64 size);

  /// \brief Modifies the size of the dimension `d` to be `size`
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64 size);

  /// Same as `set_dim` but returns a `Status`.
  /// Use if unsure whether the requirements in `set_dim` are satisfied, to
  /// prevent `CHECK`-fail crashes.
  Status SetDimWithStatus(int d, int64 size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d) {
    CHECK_GE(d, 0);
    RemoveDimRange(d, d + 1);
  }

  /// Same as `RemoveDim` but returns a `Status`.
  /// Use if unsure whether `0 <= d < dims()`, to prevent `CHECK`-crashes.
  Status RemoveDimWithStatus(int64 d) {
    if (TF_PREDICT_FALSE(d < 0)) {
      return errors::Internal(
          "Expected dimension index to be non-negative, got ", d);
    }
    return RemoveDimRangeWithStatus(d, d + 1);
  }

  /// \brief Removes last `n` dimensions from the `TensorShape`.
  /// REQUIRES: `0 <= n <= dims()`
  void RemoveLastDims(int n) {
    CHECK_LE(n, dims());
    RemoveDimRange(dims() - n, dims());
  }

  /// Same as `RemoveLastDims` but returns a `Status`.
  /// Use if unsure whether `0 <= n <= dims()`, to prevent `CHECK`-crashes.
  Status RemoveLastDimsWithStatus(int64 n) {
    if (TF_PREDICT_FALSE(n > dims())) {
      return errors::Internal(
          "Expected number of dimensions to remove to be at most ", dims(),
          ", got ", n);
    }
    return RemoveDimRangeWithStatus(dims() - n, dims());
  }

  /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
  /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
  /// Python). The same is true for negative values of `begin`.
  /// REQUIRES: `-(dims()+1) <= begin <= dims()`
  /// REQUIRES: `-(dims()+1) <= end <= dims()`
  void RemoveDimRange(int begin, int end);

  /// Same as `RemoveDimRange` but returns a `Status`.
  /// Use if unsure whether the requirements in `RemoveDimRange` are satisfied,
  /// to prevent `CHECK`-fail crashes.
  Status RemoveDimRangeWithStatus(int begin, int end);
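
  // Worked sketch of the removal helpers above (illustrative only):
  //
  //   TensorShape s({2, 3, 5, 7});
  //   s.RemoveDim(1);            // s is now [2, 5, 7]
  //   s.RemoveLastDims(1);       // s is now [2, 5]
  //
  //   TensorShape t({2, 3, 5, 7});
  //   t.RemoveDimRange(1, 3);    // removes dims [1, 3): t is now [2, 7]
  //
  //   TensorShape u({2, 3, 5, 7});
  //   u.RemoveDimRange(-3, -1);  // begin = 4-3+1 = 2, end = 4-1+1 = 4,
  //                              // so dims [2, 4) are removed: u is now [2, 3]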

  /// Return whether the rank is unknown
  bool unknown_rank() const {
    return kIsPartial && ndims_byte() == kUnknownRank;
  }

  /// Return the number of dimensions in the tensor.
  /// Can be -1 meaning unknown rank for PartialTensorShape.
  int dims() const {
    uint8 dims = ndims_byte();
    return kIsPartial && dims == kUnknownRank ? -1 : dims;
  }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64 dim_size(int d) const;

  /// Returns sizes of all dimensions.
  // Returns an empty list for unknown rank PartialTensorShape.
  gtl::InlinedVector<int64, 4> dim_sizes() const;

  /// Return true iff the rank and all of the dimensions are well defined
  // TODO(irving): Rename to is_fully_defined now that it's fast.
  bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// For iterating through the dimensions.
  TensorShapeIter<Shape> begin() const;
  TensorShapeIter<Shape> end() const;

 protected:
  // Optimized constructor for a shape representing an empty vector.
  //
  // This constructor is provided to optimize the default constructor for
  // `Tensor`.
  explicit TensorShapeBase(DataType dt);

 private:
  Status RecomputeNumElements();
  Status InitDims(gtl::ArraySlice<int64> dim_sizes);

  // True for PartialTensorShape, false for TensorShape
  static constexpr bool kIsPartial =
      std::is_same<Shape, PartialTensorShape>::value;
  static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
                "Shape is neither TensorShape nor PartialTensorShape");

  // Used by AddDim and MakeShapeHelper. Does no error checking.
  void UnsafeAddDim(int64 size, int64 new_num_elements);

  // For use by TensorShapeUtils::MakeShape
  template <class T, class S>
  friend Status MakeShapeHelper(const T*, int64, S*);
};

/// Outputs `TensorShapeBase` to `std::ostream`.
template <typename Shape>
std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
  return os << tsb.DebugString();
}

/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have
/// a shape of 2-D, [3,4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
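///
/// An illustrative usage sketch (not exhaustive):
///
///     TensorShape shape({3, 4});           // 2-D, [3,4]
///     shape.AddDim(5);                     // now 3-D, [3,4,5]
///     int64 elems = shape.num_elements();  // 60
///
///     TensorShape scalar;                  // 0 dimensions, 1 element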
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  /// Allow a TensorShape to be used as a PartialTensorShape without copying
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;

  // Same as `AsEigenDSizes()` but returns a `Status` instead.
  // Use this method to surface errors to the caller instead of crashing if
  // `NDIMS` is not equal to `dims()`.
  // Caller must take ownership of `out`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;

  // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
  // Use this method to surface errors to the caller instead of crashing if
  // `NDIMS` is smaller than `dims()`.
  // Caller must take ownership of `out`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithPaddingWithStatus(
      Eigen::DSizes<IndexType, NDIMS>* out) const;
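
  // Illustrative sketch of the Eigen conversions above, for a rank-2 shape:
  //
  //   TensorShape s({3, 4});
  //   auto two = s.AsEigenDSizes<2>();              // Eigen::DSizes {3, 4}
  //   auto four = s.AsEigenDSizesWithPadding<4>();  // {3, 4, 1, 1}
  //
  //   Eigen::DSizes<Eigen::DenseIndex, 2> checked;
  //   Status status = s.AsEigenDSizesWithStatus<2>(&checked);  // OK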

 private:
  // These CHECK-fail on violation, to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() >= NDIMS
  void CheckDimsAtLeast(int NDIMS) const;

  // Fill output from `*this`.
  // Helper method for common code between `AsEigenDSizes()` and
  // `AsEigenDSizesWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;

  // Fill output from `*this`.
  // Helper method for common code between `AsEigenDSizesWithPadding()` and
  // `AsEigenDSizesWithPaddingWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;

  // For access to TensorShapeBase(DataType).
  friend class Tensor;
};

/// Outputs `TensorShape` to `std::ostream`.
inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) {
  return os << ts.DebugString();
}

/// Represents the value of one dimension in a TensorShape.
struct TensorShapeDim {
  explicit TensorShapeDim(int64 s) : size(s) {}
  int64 size;
};

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter {
 public:
  TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }

 private:
  const Shape* shape_;
  int d_;
};
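
// Illustrative iteration sketch: ranging over a shape yields one
// TensorShapeDim per dimension.
//
//   TensorShape shape({2, 3, 4});
//   int64 product = 1;
//   for (const TensorShapeDim& d : shape) {
//     product *= d.size;  // product is 24 after the loop
//   }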
// END_SKIP_DOXYGEN

/// \brief Static helper routines for `TensorShape`. Includes a few common
/// predicates on a tensor shape.
class TensorShapeUtils {
 public:
  static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }

  static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }

  static bool IsVectorOrHigher(const TensorShape& shape) {
    return shape.dims() >= 1;
  }

  static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }

  static bool IsSquareMatrix(const TensorShape& shape) {
    return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
  }

  static bool IsMatrixOrHigher(const TensorShape& shape) {
    return shape.dims() >= 2;
  }

  /// \brief Returns a `TensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
  static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
  static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape,
                          PartialTensorShape* out);
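
  // Minimal sketch (illustrative) of building a shape from a raw buffer of
  // sizes, e.g. values received from an untrusted source:
  //
  //   int64 dims[] = {2, 3, 4};
  //   TensorShape shape;
  //   Status s = TensorShapeUtils::MakeShape(dims, 3, &shape);
  //   // On failure `s` is an error and `shape` should not be used.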

  static std::string ShapeListString(
      const gtl::ArraySlice<TensorShape>& shapes);

  /// \brief Returns true iff `shape` starts with `prefix`.
  static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);

  /// \brief Returns true iff `shape` ends with `suffix`.
  static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);

  /// \brief Returns the product of values in an int64 array,
  /// or a failing Status if the array represents a value larger than
  /// a `TensorShape` can hold.
  static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
};

/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64 size) const;

  /// Similar to `Concatenate` but returning `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(int64 size, PartialTensorShape* out) const;

  /// Appends all the dimensions from `shape`. Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Similar to `Concatenate` but returning `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(const PartialTensorShape& shape,
                               PartialTensorShape* out) const;

  /// Merges all the dimensions from `shape`. Returns an
  /// `InvalidArgument` error if `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are identical
  /// (i.e., for each dimension, both are unknown, or both are known and
  /// equal). This is a stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state. Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;
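
  // Illustrative sketch of combining partially known shapes (a dimension of
  // -1 denotes "unknown", as documented above):
  //
  //   PartialTensorShape a({-1, 2});  // [?, 2]
  //   PartialTensorShape b({3, -1});  // [3, ?]
  //
  //   PartialTensorShape merged;
  //   Status s = a.MergeWith(b, &merged);       // ok: merged is [3, 2]
  //   bool compatible = a.IsCompatibleWith(b);  // true
  //   bool identical = a.IsIdenticalTo(b);      // false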

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};

/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  static std::string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};

// ----------------------------------------------------------------------------
// Template method implementation details below
// ----------------------------------------------------------------------------

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < NDIMS; d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  return dsizes;
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
  static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < dims(); d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  for (int d = dims(); d < NDIMS; d++) {
    dsizes[d] = 1;
  }
  return dsizes;
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesCopy<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithStatus(
    Eigen::DSizes<IndexType, NDIMS>* out) const {
  if (TF_PREDICT_FALSE(NDIMS != dims())) {
    return errors::Internal("Asking for tensor of ", NDIMS,
                            " dimensions from a tensor of ", dims(),
                            " dimensions");
  }
  *out = AsEigenDSizesCopy<NDIMS, IndexType>();
  return Status::OK();
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
  CheckDimsAtLeast(NDIMS);
  return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
    Eigen::DSizes<IndexType, NDIMS>* out) const {
  if (TF_PREDICT_FALSE(NDIMS < dims())) {
    return errors::Internal("Asking for tensor of at least ", NDIMS,
                            " dimensions from a tensor of ", dims(),
                            " dimensions");
  }
  *out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
  return Status::OK();
}

// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------

inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    // set_ndims_byte(b.ndims_byte());
    // set_tag(b.tag());
  } else {
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    SlowCopyFrom(b);
  }
}

inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly also does:
  // set_ndims_byte(b.ndims_byte());
  // set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShapeRep::~TensorShapeRep() {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}

inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    // set_tag(b.tag());
    // set_ndims_byte(b.ndims_byte());
  } else {
    SlowCopyFrom(b);
  }
}

inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly also does:
  // set_ndims_byte(b.ndims_byte());
  // set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}

template <class Shape>
inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
  set_tag(REP16);
  set_data_type(dt);

  // Optimized implementation of InitDims() where the shape is statically known
  // to be {0}.
  set_ndims_byte(1);
  uint16* dst = as16()->dims_;
  *dst = 0;
  set_num_elements(0);
}

// Declare explicit instantiations in .cc file
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;

// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
// often used in shape inference.
struct DtypeAndPartialTensorShape {
  DataType dtype;
  PartialTensorShape shape;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_