1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
18
19 #include <string>
20
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/core/stringpiece.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32
33 namespace tensorflow {
34
// START_SKIP_DOXYGEN
// Forward declarations needed before the full class definitions below.
class PartialTensorShape;
class TensorShape;
class TensorShapeProto;
template <class Shape>
class TensorShapeIter;
// END_SKIP_DOXYGEN
42
/// Internal representation for both TensorShape and PartialTensorShape.
///
/// The per-dimension data, the rank, the representation tag and an optional
/// 8-bit data type all live in the 16-byte buffer `u_`; the total element
/// count is cached separately in `num_elements_`.
class TensorShapeRep {
 public:
  /// Destroys the shape. When the dimensions are stored out of line
  /// (`REP_OUT_OF_LINE`), cleanup is delegated to `DestructorOutOfLine()`.
  ~TensorShapeRep();

  /// Copy the specified shape.
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape. After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  /// Maximum number of dimensions in a tensor.
  /// It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  std::string DebugString() const;
  static std::string DebugString(const TensorShapeProto& proto);

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  // Defined out of line. NOTE(review): judging by the name this resets the
  // shape while keeping the data-type byte (byte 13) — confirm in the .cc.
  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape. Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static constexpr int64_t kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static constexpr int64_t kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  // Typed views of the shared buffer, one per encoding. Only the view that
  // matches tag() holds meaningful dimension data.
  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  // Identifies which encoding of the buffer is active; stored in byte 15.
  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage. This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Byte 13 holds the data type (see set_data_type()). The dimension encodings
  // use the leading bytes: Rep16 and Rep32 occupy bytes [0..11], and Rep64
  // stores a single pointer at the start of the buffer.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static constexpr uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  // Only updates the cached element count; it is not recomputed from the
  // dimensions here.
  void set_num_elements(int64_t n) { num_elements_ = n; }

 private:
  // Defined out of line (not in this header).
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64 num_elements_;
};
160
161 /// Base class for TensorShape and PartialTensorShape.
162 /// The class is templatized by either TensorShape or PartialTensorShape to
163 /// allow skipping known/unknown checks in the TensorShape case, but the
164 /// representation is shared exactly for fast conversion.
165 template <class Shape>
166 class TensorShapeBase : public TensorShapeRep {
167 public:
168 /// \brief Construct a `TensorShapeBase` from the provided sizes.
169 /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
170 explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
TensorShapeBase(std::initializer_list<int64> dim_sizes)171 TensorShapeBase(std::initializer_list<int64> dim_sizes)
172 : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}
173
174 /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
175 TensorShapeBase();
176
177 // TODO(mihaimaruseac): Mark this explicit in a subsequent change
178 TensorShapeBase(const TensorShapeProto& proto);
179
180 // These factory methods should be used instead of the constructors that take
181 // an array of sizes if calling code cannot validate that the sizes specify a
182 // valid `TensorShape`.
183 // The value in `*out` is valid iff the returned value is `Status::OK`.
184 static Status BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,
185 TensorShapeBase* out);
BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,TensorShapeBase * out)186 static Status BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,
187 TensorShapeBase* out) {
188 return BuildTensorShapeBase(gtl::ArraySlice<int64>(dim_sizes), out);
189 }
190 static Status BuildTensorShapeBase(const TensorShapeProto& proto,
191 TensorShapeBase* out);
192
193 /// Returns `true` iff `proto` is a valid tensor shape.
194 // For TensorShape, the proto shape must be fully defined.
195 static bool IsValid(const TensorShapeProto& proto);
196
197 /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
198 /// status otherwise.
199 static Status IsValidShape(const TensorShapeProto& proto);
200
201 /// Returns `true` iff this is a valid tensor shape.
202 bool IsValid();
203
204 /// \brief Add a dimension to the end ("inner-most").
205 /// REQUIRES: `size >= 0`
206 void AddDim(int64_t size);
207
208 /// Same as `AddDim` but returns a `Status`.
209 /// Use if unsure is `size >= 0`, to prevent `CHECK`-crashes.
210 Status AddDimWithStatus(int64_t size);
211
212 /// Appends all the dimensions from `shape`.
213 void AppendShape(const TensorShapeBase& shape);
214
215 /// Same as `RemoveDim` but returns a `Status`.
216 /// Use if you cannot validate all invariants, to prevent `CHECK`-fail.
217 Status AppendShapeWithStatus(const TensorShapeBase& shape);
218
219 /// \brief Insert a dimension somewhere in the `TensorShape`.
220 /// REQUIRES: `0 <= d <= dims()`
221 /// REQUIRES: `size >= 0`
222 void InsertDim(int d, int64_t size);
223
224 /// Same as `InsertDim` but returns a `Status`.
225 /// Use if unsure if requirements in `InsertDim` are satistified, to prevent
226 /// `CHECK`-fail crashes.
227 Status InsertDimWithStatus(int d, int64_t size);
228
229 /// \brief Modifies the size of the dimension `d` to be `size`
230 /// REQUIRES: `0 <= d < dims()`
231 /// REQUIRES: `size >= 0`
232 void set_dim(int d, int64_t size);
233
234 /// Same as `set_dim` but returns a `Status`.
235 /// Use if unsure if requirements in `set_dim` are satistified, to prevent
236 /// `CHECK`-fail crashes.
237 Status SetDimWithStatus(int d, int64_t size);
238
239 /// \brief Removes dimension `d` from the `TensorShape`.
240 /// REQUIRES: `0 <= d < dims()`
RemoveDim(int d)241 void RemoveDim(int d) {
242 CHECK_GE(d, 0);
243 RemoveDimRange(d, d + 1);
244 }
245
246 /// Same as `RemoveDim` but returns a `Status`.
247 /// Use if unsure is `0 <= d < dims()`, to prevent `CHECK`-crashes.
RemoveDimWithStatus(int64_t d)248 Status RemoveDimWithStatus(int64_t d) {
249 if (TF_PREDICT_FALSE(d < 0)) {
250 return errors::Internal(
251 "Expected dimension index to be non-negative, got ", d);
252 }
253 return RemoveDimRangeWithStatus(d, d + 1);
254 }
255
256 /// \brief Removes last `n` dimensions from the `TensorShape`.
257 /// REQUIRES: `0 <= n <= dims()`
RemoveLastDims(int n)258 void RemoveLastDims(int n) {
259 CHECK_LE(n, dims());
260 RemoveDimRange(dims() - n, dims());
261 }
262
263 /// Same as `RemoveLastDims` but returns a `Status`.
264 /// Use if unsure is `0 <= n <= dims()`, to prevent `CHECK`-crashes.
RemoveLastDimsWithStatus(int64_t n)265 Status RemoveLastDimsWithStatus(int64_t n) {
266 if (TF_PREDICT_FALSE(n < dims())) {
267 return errors::Internal("Expected dimension index to be at most ", dims(),
268 " got ", n);
269 }
270 return RemoveDimRangeWithStatus(dims() - n, dims());
271 }
272
273 /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
274 /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
275 /// Python). The same is true for negative values of `begin`.
276 /// REQUIRES: `-(dims()+1) <= begin <= dims()`
277 /// REQUIRES: `-(dims()+1) <= end <= dims()`
278 void RemoveDimRange(int begin, int end);
279
280 /// Same as `RemoveDimRange` but returns a `Status`.
281 /// Use if unsure if requirements in `RemoveDimRange` are satistified, to
282 /// prevent `CHECK`-fail crashes.
283 Status RemoveDimRangeWithStatus(int begin, int end);
284
285 /// Return whether the rank is unknown
unknown_rank()286 bool unknown_rank() const {
287 return kIsPartial && ndims_byte() == kUnknownRank;
288 }
289
290 /// Return the number of dimensions in the tensor.
291 /// Can be -1 meaning unknown rank for PartialTensorShape.
dims()292 int dims() const {
293 uint8 dims = ndims_byte();
294 return kIsPartial && dims == kUnknownRank ? -1 : dims;
295 }
296
297 /// \brief Returns the number of elements in dimension `d`.
298 /// REQUIRES: `0 <= d < dims()`
299 // TODO(touts): Rename to `dimension()` to match
300 // `Eigen::Tensor::dimension()`?
301 int64 dim_size(int d) const;
302
303 /// Returns sizes of all dimensions.
304 // Returns an empty list for unknown rank PartialTensorShape.
305 gtl::InlinedVector<int64, 4> dim_sizes() const;
306
307 /// Return true iff the rank and all of the dimensions are well defined
308 // TODO(irving): Rename to is_fully_defined now that it's fast.
IsFullyDefined()309 bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }
310
311 /// Fill `*proto` from `*this`.
312 void AsProto(TensorShapeProto* proto) const;
313
314 /// For iterating through the dimensions.
315 TensorShapeIter<Shape> begin() const;
316 TensorShapeIter<Shape> end() const;
317
318 protected:
319 // Optimized constructor for a shape representing an empty vector.
320 //
321 // This constructor is provided to optimize the default constructor for
322 // `Tensor`.
323 explicit TensorShapeBase(DataType dt);
324
325 private:
326 Status RecomputeNumElements();
327 Status InitDims(gtl::ArraySlice<int64> dim_sizes);
328
329 // True for PartialTensorShape, false for TensorShape
330 static constexpr bool kIsPartial =
331 std::is_same<Shape, PartialTensorShape>::value;
332 static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
333 "Shape is neither TensorShape nor PartialTensorShape");
334
335 // Used by AddDim and MakeShapeHelper. Does no error checking.
336 void UnsafeAddDim(int64_t size, int64_t new_num_elements);
337
338 // For use by TensorShapeUtils::MakeShape
339 template <class T, class S>
340 friend Status MakeShapeHelper(const T*, int64_t, S*);
341 };
342
343 /// Outputs `TensorShapeBase` to `std::ostream`.
344 template <typename Shape>
345 std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
346 return os << tsb.DebugString();
347 }
348
349 /// Represents the shape of a Tensor.
350 ///
351 /// A tensor's shape is denoted by its number of dimensions and a size for each
352 /// dimension. For example, a Tensor represented by a 3 x 4 matrix would have
353 /// a shape of 2-D, [3,4].
354 ///
355 /// If you know the exact shape of your Tensor when you create the TensorShape
356 /// object, you can specify it then, or you can create a TensorShape with
357 /// zero dimensions and one element, and call AddDim() to add dimensions later.
358 class TensorShape : public TensorShapeBase<TensorShape> {
359 public:
360 using TensorShapeBase<TensorShape>::TensorShapeBase;
361
362 /// Allow a TensorShape to be used as a PartialTensorShape without copying
363 operator const PartialTensorShape&() const; // NOLINT(runtime/explicit)
364
365 /// Returns true if `*this` and `b` have the same sizes. Ignores
366 /// dimension names.
367 bool IsSameSize(const TensorShape& b) const;
368 bool operator==(const TensorShape& b) const { return IsSameSize(b); }
369 bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }
370
371 /// Fill `*dsizes` from `*this`.
372 /// Notice: Using IndexType=int32 in combination with To32Bit() can
373 /// significantly improve performance on GPU.
374 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
375 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
376
377 // Same as `AsEigenDSizes()` but returns a `Status` instead.
378 // Use this method to surface error to user instead of crashing if `NDMIS` is
379 // not equal to `dims()`.
380 // Caller must take ownership of `out`.
381 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
382 Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;
383
384 /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
385 /// which case we pad the rest of the sizes with 1.
386 /// Notice: Using IndexType=int32 in combination with To32Bit() can
387 /// significantly improve performance on GPU.
388 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
389 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
390
391 // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
392 // Use this method to surface error to user instead of crashing if `NDMIS` is
393 // not equal to `dims()`.
394 // Caller must take ownership of `out`.
395 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
396 Status AsEigenDSizesWithPaddingWithStatus(
397 Eigen::DSizes<IndexType, NDIMS>* out) const;
398
399 private:
400 // These CHECK fail to ease debugging.
401 // REQUIRES: dims() == NDIMS
402 void CheckDimsEqual(int NDIMS) const;
403 // REQUIRES: dims() >= NDIMS
404 void CheckDimsAtLeast(int NDIMS) const;
405
406 // Fill output from `*this`.
407 // Helper method for common code between `AsEigenDSize()` and
408 // `AsEigenDSizeWithStatus()`.
409 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
410 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;
411
412 // Fill output from `*this`.
413 // Helper method for common code between `AsEigenDSizesWithPadding()` and
414 // `AsEigenDSizeWithPaddingWithStatus()`.
415 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
416 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;
417
418 // For access to TensorShapeBase(DataType).
419 friend class Tensor;
420 };
421
422 /// Outputs `TensorShapeBase` to `std::ostream`.
423 inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) {
424 return os << ts.DebugString();
425 }
426
427 /// Represents the value of one dimension in a TensorShape.
428 struct TensorShapeDim {
TensorShapeDimTensorShapeDim429 explicit TensorShapeDim(int64_t s) : size(s) {}
430 int64 size;
431 };
432
433 // START_SKIP_DOXYGEN
434 template <class Shape>
435 class TensorShapeIter {
436 public:
TensorShapeIter(const Shape * shape,int d)437 TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
438 bool operator==(const TensorShapeIter& rhs) {
439 DCHECK(shape_ == rhs.shape_);
440 return d_ == rhs.d_;
441 }
442 bool operator!=(const TensorShapeIter& rhs) {
443 DCHECK(shape_ == rhs.shape_);
444 return d_ != rhs.d_;
445 }
446 void operator++() { ++d_; }
447 TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
448
449 private:
450 const Shape* shape_;
451 int d_;
452 };
453 // END_SKIP_DOXYGEN
454
455 /// \brief Static helper routines for `TensorShape`. Includes a few common
456 /// predicates on a tensor shape.
457 class TensorShapeUtils {
458 public:
IsScalar(const TensorShape & shape)459 static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
460
IsVector(const TensorShape & shape)461 static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
462
IsVectorOrHigher(const TensorShape & shape)463 static bool IsVectorOrHigher(const TensorShape& shape) {
464 return shape.dims() >= 1;
465 }
466
IsMatrix(const TensorShape & shape)467 static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
468
IsSquareMatrix(const TensorShape & shape)469 static bool IsSquareMatrix(const TensorShape& shape) {
470 return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
471 }
472
IsMatrixOrHigher(const TensorShape & shape)473 static bool IsMatrixOrHigher(const TensorShape& shape) {
474 return shape.dims() >= 2;
475 }
476
477 /// \brief Returns a `TensorShape` whose dimensions are
478 /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
479 static Status MakeShape(const int32* dims, int64_t n, TensorShape* out);
480 static Status MakeShape(const int64* dims, int64_t n, TensorShape* out);
481 static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
482 static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
483 static Status MakeShape(const int32* dims, int64_t n,
484 PartialTensorShape* out);
485 static Status MakeShape(const int64* dims, int64_t n,
486 PartialTensorShape* out);
487 static Status MakeShape(gtl::ArraySlice<int32> shape,
488 PartialTensorShape* out);
489 static Status MakeShape(gtl::ArraySlice<int64> shape,
490 PartialTensorShape* out);
491
492 static std::string ShapeListString(
493 const gtl::ArraySlice<TensorShape>& shapes);
494
495 /// \brief Returns true iff `shape` starts with `prefix`.
496 static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);
497
498 /// \brief Returns true iff `shape` ends with `suffix`.
499 static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);
500
501 /// \brief Returns the product of values in an int64 array,
502 /// or a failing Status if the array represents a value larger than
503 /// a `TensorShape` can hold.
504 static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
505 };
506
507 /// Manages the partially known dimensions of a Tensor and their sizes.
508 class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
509 public:
PartialTensorShape()510 PartialTensorShape() {}
511 using TensorShapeBase<PartialTensorShape>::TensorShapeBase;
512
513 /// Add a dimension to the end ("inner-most"), returns a new
514 /// PartialTensorShape.
515 /// REQUIRES: `size >= -1`, where -1 means unknown
516 PartialTensorShape Concatenate(int64_t size) const;
517
518 /// Similar to `Concatenate` but returning `Status`.
519 /// Use if calling code cannot validate all requirements and if `CHECK`-fails
520 /// are to be avoided.
521 Status ConcatenateWithStatus(int64_t size, PartialTensorShape* out) const;
522
523 /// Appends all the dimensions from `shape`. Returns a new
524 /// PartialTensorShape.
525 PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
526
527 /// Similar to `Concatenate` but returning `Status`.
528 /// Use if calling code cannot validate all requirements and if `CHECK`-fails
529 /// are to be avoided.
530 Status ConcatenateWithStatus(const PartialTensorShape& shape,
531 PartialTensorShape* out) const;
532
533 /// Merges all the dimensions from `shape`. Returns
534 /// `InvalidArgument` error if either `shape` has a different rank
535 /// or if any of the dimensions are incompatible.
536 Status MergeWith(const PartialTensorShape& shape,
537 PartialTensorShape* result) const;
538
539 /// Exact equality test. Returns true iff the ranks match (i.e., both are
540 /// unknown, or both are known and equal), and all dimensions are equal (i.e.,
541 /// both dimensions are known, or both are known and equal). This is a
542 /// stronger condition that IsCompatibleWith.
543 bool IsIdenticalTo(const PartialTensorShape& shape) const;
544
545 /// Return true iff the ranks match, and if the
546 /// dimensions all either match or one is unknown.
547 bool IsCompatibleWith(const PartialTensorShape& shape) const;
548
549 // Fill `*shape` from `*this`.
550 // If `*this` is not fully defined, returns false and
551 // `*shape` is left in an intermediate state. Otherwise
552 // returns true.
553 bool AsTensorShape(TensorShape* shape) const;
554
555 /// \brief Returns a `PartialTensorShape` whose dimensions are
556 /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
557 /// considered "unknown".
558 template <class T>
MakePartialShape(const T * dims,int n,PartialTensorShape * out)559 static Status MakePartialShape(const T* dims, int n,
560 PartialTensorShape* out) {
561 return TensorShapeUtils::MakeShape(dims, n, out);
562 }
563 };
564
565 /// \brief Static helper routines for `PartialTensorShape`. Includes a few
566 /// common predicates on a partially known tensor shape.
567 class PartialTensorShapeUtils {
568 public:
569 static std::string PartialShapeListString(
570 const gtl::ArraySlice<PartialTensorShape>& shapes);
571
572 static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
573 const gtl::ArraySlice<PartialTensorShape>& shapes1);
574
575 static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
576 const gtl::ArraySlice<PartialTensorShape>& shapes1);
577 };
578
579 // ----------------------------------------------------------------------------
580 // Template method implementation details below
581 // ----------------------------------------------------------------------------
582
583 template <int NDIMS, typename IndexType>
AsEigenDSizesCopy()584 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
585 Eigen::DSizes<IndexType, NDIMS> dsizes;
586 for (int d = 0; d < NDIMS; d++) {
587 dsizes[d] = static_cast<IndexType>(dim_size(d));
588 }
589 return dsizes;
590 }
591
592 template <int NDIMS, typename IndexType>
AsEigenDSizesCopyAndPad()593 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
594 static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
595 Eigen::DSizes<IndexType, NDIMS> dsizes;
596 for (int d = 0; d < dims(); d++) {
597 dsizes[d] = static_cast<IndexType>(dim_size(d));
598 }
599 for (int d = dims(); d < NDIMS; d++) {
600 dsizes[d] = 1;
601 }
602 return dsizes;
603 }
604
605 template <int NDIMS, typename IndexType>
AsEigenDSizes()606 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
607 CheckDimsEqual(NDIMS);
608 return AsEigenDSizesCopy<NDIMS, IndexType>();
609 }
610
611 template <int NDIMS, typename IndexType>
AsEigenDSizesWithStatus(Eigen::DSizes<IndexType,NDIMS> * out)612 Status TensorShape::AsEigenDSizesWithStatus(
613 Eigen::DSizes<IndexType, NDIMS>* out) const {
614 if (TF_PREDICT_FALSE(NDIMS != dims())) {
615 return errors::Internal("Asking for tensor of ", NDIMS,
616 " dimensions from a tensor of ", dims(),
617 " dimensions");
618 }
619 *out = AsEigenDSizesCopy<NDIMS, IndexType>();
620 }
621
622 template <int NDIMS, typename IndexType>
AsEigenDSizesWithPadding()623 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
624 CheckDimsAtLeast(NDIMS);
625 return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
626 }
627
628 template <int NDIMS, typename IndexType>
AsEigenDSizesWithPaddingWithStatus(Eigen::DSizes<IndexType,NDIMS> * out)629 Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
630 Eigen::DSizes<IndexType, NDIMS>* out) const {
631 if (TF_PREDICT_FALSE(NDIMS < dims())) {
632 return errors::Internal("Asking for tensor of at least ", NDIMS,
633 " dimensions from a tensor of ", dims(),
634 " dimensions");
635 }
636 *out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
637 }
638
639 // ----------------------------------------------------------------------------
640 // Inlining of some performance critical routines
641 // ----------------------------------------------------------------------------
642
inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() == REP_OUT_OF_LINE) {
    // Our buffer holds indeterminate bytes at this point; give it an inline
    // tag first so that SlowCopyFrom does not try to deallocate.
    set_tag(REP16);
    SlowCopyFrom(b);
    return;
  }
  // Inline encodings are plain bytes: one memcpy copies the dimensions, the
  // data type byte, the rank byte and the tag together.
  memcpy(buf(), b.buf(), sizeof(u_.buf));
}
655
TensorShapeRep(TensorShapeRep && b)656 inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
657 num_elements_ = b.num_elements_;
658 memcpy(buf(), b.buf(), sizeof(u_.buf));
659 // memcpy above Implicitly does:
660 // set_ndims_byte(b.ndims_byte());
661 // set_tag(b.tag());
662 b.set_tag(REP16); // other shape no longer owns out-of-line data, if any.
663 }
664
~TensorShapeRep()665 inline TensorShapeRep::~TensorShapeRep() {
666 if (tag() == REP_OUT_OF_LINE) {
667 DestructorOutOfLine();
668 }
669 }
670
671 inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
672 num_elements_ = b.num_elements_;
673 if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
674 memcpy(buf(), b.buf(), sizeof(u_.buf));
675 // memcpy above implicitly also does:
676 // set_tag(b.tag());
677 // set_ndims_byte(b.ndims_byte());
678 } else {
679 SlowCopyFrom(b);
680 }
681 }
682
683 inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
684 if (tag() == REP_OUT_OF_LINE) {
685 DestructorOutOfLine();
686 }
687 num_elements_ = b.num_elements_;
688 memcpy(buf(), b.buf(), sizeof(u_.buf));
689 // memcpy above Implicitly does:
690 // set_ndims_byte(b.ndims_byte());
691 // set_tag(b.tag());
692 b.set_tag(REP16); // other shape no longer owns out-of-line data, if any.
693 }
694
695 inline TensorShape::operator const PartialTensorShape&() const {
696 // Downcast to the shared representation and upcast to PartialTensorShape
697 const TensorShapeRep* rep = this;
698 return *static_cast<const PartialTensorShape*>(rep);
699 }
700
701 template <class Shape>
TensorShapeBase(DataType dt)702 inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
703 set_tag(REP16);
704 set_data_type(dt);
705
706 // Optimized implementation of InitDims() where the shape is statically known
707 // to be {0}.
708 set_ndims_byte(1);
709 uint16* dst = as16()->dims_;
710 *dst = 0;
711 set_num_elements(0);
712 }
713
714 // Declare explicit instantiations in .cc file
715 extern template class TensorShapeBase<TensorShape>;
716 extern template class TensorShapeBase<PartialTensorShape>;
717
// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
// often used in shape inference.
struct DtypeAndPartialTensorShape {
  // Element type. It has no default member initializer, so it is left
  // indeterminate by default-initialization; set it explicitly (e.g. via
  // aggregate initialization `{dtype, shape}`).
  DataType dtype;
  // Shape; a default-constructed PartialTensorShape has unknown rank.
  PartialTensorShape shape;
};
724
725 } // namespace tensorflow
726
727 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
728