• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
18 
19 #include <string>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/core/stringpiece.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 
33 namespace tensorflow {
34 
// START_SKIP_DOXYGEN
// Forward declarations: these types are named below before (or without) a
// full definition appearing in this header.
class PartialTensorShape;
class TensorShape;
class TensorShapeProto;
template <class Shape>
class TensorShapeIter;
// END_SKIP_DOXYGEN
42 
43 /// Internal representation for both TensorShape and PartialTensorShape.
44 class TensorShapeRep {
45  public:
46   ~TensorShapeRep();
47 
48   /// Copy the specified shape
49   TensorShapeRep(const TensorShapeRep& b);
50   void operator=(const TensorShapeRep& b);
51 
52   /// Move the specified shape.  After moving, `b` is safe for destruction and
53   // can be reassigned into, but its dimensions and number of elements can be
54   // nonsensical (e.g., negative dimension sizes, or number of elements not
55   // properly recomputed).
56   TensorShapeRep(TensorShapeRep&& b);
57   void operator=(TensorShapeRep&& b);
58 
59   /// Clear a tensor shape, producing the scalar shape.
60   void Clear();
61 
62   // Maximum number of dimensions in a tensor.
63   // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
MaxDimensions()64   static constexpr int MaxDimensions() { return 254; }
65 
66   /// \brief Returns the number of elements in the tensor.
67   ///
68   /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
69   /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
70   /// defined.
num_elements()71   int64 num_elements() const { return num_elements_; }
72 
73   /// For error messages.
74   std::string DebugString() const;
75   static std::string DebugString(const TensorShapeProto& proto);
76 
77  protected:
78   // Constructable only via TensorShapeBase
79   TensorShapeRep() = default;
80 
81   void ClearAllButDataType();
82 
83   // We use 16 bytes to represent a TensorShape.  Because we need to
84   // be able to support full 64-bit dimension sizes and an arbitrary
85   // number of dimensions for a Tensor, but most tensor dimensions are
86   // significantly smaller than 64 bits and most tensors are 1, 2, or 3
87   // dimensions, we have several representations.
88   // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
89   // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
90   // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
91   //        an out of line vector.
92   // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
93   // This value is not allowed in TensorShape either for format compatibility.
94   struct Rep16 {
95     uint16 dims_[6];
96   };
97   struct Rep32 {
98     uint32 dims_[3];
99   };
100   struct Rep64 {
101     gtl::InlinedVector<int64, 4>* dims_;
102   };
103 
104   // We use the max value of uint16 or uint32 to represent unknown shapes, so
105   // the maximum representable valid shape in these representations is one less.
106   static constexpr int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
107   static constexpr int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
108   static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
109   static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();
110 
as16()111   Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
as32()112   Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
as64()113   Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }
114 
as16()115   const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
as32()116   const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
as64()117   const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }
118 
119   enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };
120 
121   // Since we have a convenient extra byte available, we allow the
122   // Tensor class to store an 8-bit value in this extra storage.  This
123   // allows it to store the Tensor's datatype enum value here and avoid
124   // an extra word of storage.
125   friend class Tensor;
126   friend class TensorShapeTestHelper;
data_type()127   DataType data_type() const { return static_cast<DataType>(buf()[13]); }
set_data_type(DataType dt)128   void set_data_type(DataType dt) {
129     // We only have 8 bits available to store DataType, so make sure it fits
130     DCHECK_LT(static_cast<uint32>(dt), 256u);
131     buf()[13] = static_cast<uint8>(dt);
132   }
133 
134   // We store the number of dimensions in byte 14, and the RepTag in byte 15.
135   // Bytes [0..13] vary depending on the representation.
136   // A value of 255 indicates unknown rank in the PartialTensorShape case.
137   static constexpr uint8 kUnknownRank = 255;
ndims_byte()138   uint8 ndims_byte() const { return buf()[14]; }
set_ndims_byte(uint8 nd)139   void set_ndims_byte(uint8 nd) { buf()[14] = nd; }
140 
tag()141   RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
set_tag(RepTag tag)142   void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }
143 
set_num_elements(int64 n)144   void set_num_elements(int64 n) { num_elements_ = n; }
145 
146  private:
147   void DestructorOutOfLine();
148   void SlowCopyFrom(const TensorShapeRep& b);
149 
buf()150   uint8* buf() { return &u_.buf[0]; }
buf()151   const uint8* buf() const { return &u_.buf[0]; }
152 
153   union {
154     uint8 buf[16];
155     // Force data to be aligned enough for a pointer.
156     Rep64* unused_aligner;
157   } u_;
158   int64 num_elements_;
159 };
160 
161 /// Base class for TensorShape and PartialTensorShape.
162 /// The class is templatized by either TensorShape or PartialTensorShape to
163 /// allow skipping known/unknown checks in the TensorShape case, but the
164 /// representation is shared exactly for fast conversion.
165 template <class Shape>
166 class TensorShapeBase : public TensorShapeRep {
167  public:
168   /// \brief Construct a `TensorShapeBase` from the provided sizes.
169   /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
170   explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
TensorShapeBase(std::initializer_list<int64> dim_sizes)171   TensorShapeBase(std::initializer_list<int64> dim_sizes)
172       : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}
173 
174   /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
175   TensorShapeBase();
176 
177   // TODO(mihaimaruseac): Mark this explicit in a subsequent change
178   TensorShapeBase(const TensorShapeProto& proto);
179 
180   // These factory methods should be used instead of the constructors that take
181   // an array of sizes if calling code cannot validate that the sizes specify a
182   // valid `TensorShape`.
183   // The value in `*out` is valid iff the returned value is `Status::OK`.
184   static Status BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,
185                                      TensorShapeBase* out);
BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,TensorShapeBase * out)186   static Status BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,
187                                      TensorShapeBase* out) {
188     return BuildTensorShapeBase(gtl::ArraySlice<int64>(dim_sizes), out);
189   }
190   static Status BuildTensorShapeBase(const TensorShapeProto& proto,
191                                      TensorShapeBase* out);
192 
193   /// Returns `true` iff `proto` is a valid tensor shape.
194   // For TensorShape, the proto shape must be fully defined.
195   static bool IsValid(const TensorShapeProto& proto);
196 
197   /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
198   /// status otherwise.
199   static Status IsValidShape(const TensorShapeProto& proto);
200 
201   /// Returns `true` iff this is a valid tensor shape.
202   bool IsValid();
203 
204   /// \brief Add a dimension to the end ("inner-most").
205   /// REQUIRES: `size >= 0`
206   void AddDim(int64 size);
207 
208   /// Same as `AddDim` but returns a `Status`.
209   /// Use if unsure is `size >= 0`, to prevent `CHECK`-crashes.
210   Status AddDimWithStatus(int64 size);
211 
212   /// Appends all the dimensions from `shape`.
213   void AppendShape(const TensorShapeBase& shape);
214 
215   /// Same as `RemoveDim` but returns a `Status`.
216   /// Use if you cannot validate all invariants, to prevent `CHECK`-fail.
217   Status AppendShapeWithStatus(const TensorShapeBase& shape);
218 
219   /// \brief Insert a dimension somewhere in the `TensorShape`.
220   /// REQUIRES: `0 <= d <= dims()`
221   /// REQUIRES: `size >= 0`
222   void InsertDim(int d, int64 size);
223 
224   /// Same as `InsertDim` but returns a `Status`.
225   /// Use if unsure if requirements in `InsertDim` are satistified, to prevent
226   /// `CHECK`-fail crashes.
227   Status InsertDimWithStatus(int d, int64 size);
228 
229   /// \brief Modifies the size of the dimension `d` to be `size`
230   /// REQUIRES: `0 <= d < dims()`
231   /// REQUIRES: `size >= 0`
232   void set_dim(int d, int64 size);
233 
234   /// Same as `set_dim` but returns a `Status`.
235   /// Use if unsure if requirements in `set_dim` are satistified, to prevent
236   /// `CHECK`-fail crashes.
237   Status SetDimWithStatus(int d, int64 size);
238 
239   /// \brief Removes dimension `d` from the `TensorShape`.
240   /// REQUIRES: `0 <= d < dims()`
RemoveDim(int d)241   void RemoveDim(int d) {
242     CHECK_GE(d, 0);
243     RemoveDimRange(d, d + 1);
244   }
245 
246   /// Same as `RemoveDim` but returns a `Status`.
247   /// Use if unsure is `0 <= d < dims()`, to prevent `CHECK`-crashes.
RemoveDimWithStatus(int64 d)248   Status RemoveDimWithStatus(int64 d) {
249     if (TF_PREDICT_FALSE(d < 0)) {
250       return errors::Internal(
251           "Expected dimension index to be non-negative, got ", d);
252     }
253     return RemoveDimRangeWithStatus(d, d + 1);
254   }
255 
256   /// \brief Removes last `n` dimensions from the `TensorShape`.
257   /// REQUIRES: `0 <= n <= dims()`
RemoveLastDims(int n)258   void RemoveLastDims(int n) {
259     CHECK_LE(n, dims());
260     RemoveDimRange(dims() - n, dims());
261   }
262 
263   /// Same as `RemoveLastDims` but returns a `Status`.
264   /// Use if unsure is `0 <= n <= dims()`, to prevent `CHECK`-crashes.
RemoveLastDimsWithStatus(int64 n)265   Status RemoveLastDimsWithStatus(int64 n) {
266     if (TF_PREDICT_FALSE(n < dims())) {
267       return errors::Internal("Expected dimension index to be at most ", dims(),
268                               " got ", n);
269     }
270     return RemoveDimRangeWithStatus(dims() - n, dims());
271   }
272 
273   /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
274   /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
275   /// Python). The same is true for negative values of `begin`.
276   /// REQUIRES: `-(dims()+1) <= begin <= dims()`
277   /// REQUIRES: `-(dims()+1) <= end <= dims()`
278   void RemoveDimRange(int begin, int end);
279 
280   /// Same as `RemoveDimRange` but returns a `Status`.
281   /// Use if unsure if requirements in `RemoveDimRange` are satistified, to
282   /// prevent `CHECK`-fail crashes.
283   Status RemoveDimRangeWithStatus(int begin, int end);
284 
285   /// Return whether the rank is unknown
unknown_rank()286   bool unknown_rank() const {
287     return kIsPartial && ndims_byte() == kUnknownRank;
288   }
289 
290   /// Return the number of dimensions in the tensor.
291   /// Can be -1 meaning unknown rank for PartialTensorShape.
dims()292   int dims() const {
293     uint8 dims = ndims_byte();
294     return kIsPartial && dims == kUnknownRank ? -1 : dims;
295   }
296 
297   /// \brief Returns the number of elements in dimension `d`.
298   /// REQUIRES: `0 <= d < dims()`
299   // TODO(touts): Rename to `dimension()` to match
300   // `Eigen::Tensor::dimension()`?
301   int64 dim_size(int d) const;
302 
303   /// Returns sizes of all dimensions.
304   // Returns an empty list for unknown rank PartialTensorShape.
305   gtl::InlinedVector<int64, 4> dim_sizes() const;
306 
307   /// Return true iff the rank and all of the dimensions are well defined
308   // TODO(irving): Rename to is_fully_defined now that it's fast.
IsFullyDefined()309   bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }
310 
311   /// Fill `*proto` from `*this`.
312   void AsProto(TensorShapeProto* proto) const;
313 
314   /// For iterating through the dimensions.
315   TensorShapeIter<Shape> begin() const;
316   TensorShapeIter<Shape> end() const;
317 
318  protected:
319   // Optimized constructor for a shape representing an empty vector.
320   //
321   // This constructor is provided to optimize the default constructor for
322   // `Tensor`.
323   explicit TensorShapeBase(DataType dt);
324 
325  private:
326   Status RecomputeNumElements();
327   Status InitDims(gtl::ArraySlice<int64> dim_sizes);
328 
329   // True for PartialTensorShape, false for TensorShape
330   static constexpr bool kIsPartial =
331       std::is_same<Shape, PartialTensorShape>::value;
332   static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
333                 "Shape is neither TensorShape nor PartialTensorShape");
334 
335   // Used by AddDim and MakeShapeHelper.  Does no error checking.
336   void UnsafeAddDim(int64 size, int64 new_num_elements);
337 
338   // For use by TensorShapeUtils::MakeShape
339   template <class T, class S>
340   friend Status MakeShapeHelper(const T*, int64, S*);
341 };
342 
343 /// Outputs `TensorShapeBase` to `std::ostream`.
344 template <typename Shape>
345 std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
346   return os << tsb.DebugString();
347 }
348 
349 /// Represents the shape of a Tensor.
350 ///
351 /// A tensor's shape is denoted by its number of dimensions and a size for each
352 /// dimension.  For example, a Tensor represented by a 3 x 4 matrix would have
353 /// a shape of 2-D, [3,4].
354 ///
355 /// If you know the exact shape of your Tensor when you create the TensorShape
356 /// object, you can specify it then, or you can create a TensorShape with
357 /// zero dimensions and one element, and call AddDim() to add dimensions later.
358 class TensorShape : public TensorShapeBase<TensorShape> {
359  public:
360   using TensorShapeBase<TensorShape>::TensorShapeBase;
361 
362   /// Allow a TensorShape to be used as a PartialTensorShape without copying
363   operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)
364 
365   /// Returns true if `*this` and `b` have the same sizes. Ignores
366   /// dimension names.
367   bool IsSameSize(const TensorShape& b) const;
368   bool operator==(const TensorShape& b) const { return IsSameSize(b); }
369   bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }
370 
371   /// Fill `*dsizes` from `*this`.
372   /// Notice: Using IndexType=int32 in combination with To32Bit() can
373   /// significantly improve performance on GPU.
374   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
375   Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
376 
377   // Same as `AsEigenDSizes()` but returns a `Status` instead.
378   // Use this method to surface error to user instead of crashing if `NDMIS` is
379   // not equal to `dims()`.
380   // Caller must take ownership of `out`.
381   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
382   Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;
383 
384   /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
385   /// which case we pad the rest of the sizes with 1.
386   /// Notice: Using IndexType=int32 in combination with To32Bit() can
387   /// significantly improve performance on GPU.
388   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
389   Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
390 
391   // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
392   // Use this method to surface error to user instead of crashing if `NDMIS` is
393   // not equal to `dims()`.
394   // Caller must take ownership of `out`.
395   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
396   Status AsEigenDSizesWithPaddingWithStatus(
397       Eigen::DSizes<IndexType, NDIMS>* out) const;
398 
399  private:
400   // These CHECK fail to ease debugging.
401   // REQUIRES: dims() == NDIMS
402   void CheckDimsEqual(int NDIMS) const;
403   // REQUIRES: dims() >= NDIMS
404   void CheckDimsAtLeast(int NDIMS) const;
405 
406   // Fill output from `*this`.
407   // Helper method for common code between `AsEigenDSize()` and
408   // `AsEigenDSizeWithStatus()`.
409   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
410   Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;
411 
412   // Fill output from `*this`.
413   // Helper method for common code between `AsEigenDSizesWithPadding()` and
414   // `AsEigenDSizeWithPaddingWithStatus()`.
415   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
416   Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;
417 
418   // For access to TensorShapeBase(DataType).
419   friend class Tensor;
420 };
421 
422 /// Outputs `TensorShapeBase` to `std::ostream`.
423 inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) {
424   return os << ts.DebugString();
425 }
426 
427 /// Represents the value of one dimension in a TensorShape.
428 struct TensorShapeDim {
TensorShapeDimTensorShapeDim429   explicit TensorShapeDim(int64 s) : size(s) {}
430   int64 size;
431 };
432 
433 // START_SKIP_DOXYGEN
434 template <class Shape>
435 class TensorShapeIter {
436  public:
TensorShapeIter(const Shape * shape,int d)437   TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
438   bool operator==(const TensorShapeIter& rhs) {
439     DCHECK(shape_ == rhs.shape_);
440     return d_ == rhs.d_;
441   }
442   bool operator!=(const TensorShapeIter& rhs) {
443     DCHECK(shape_ == rhs.shape_);
444     return d_ != rhs.d_;
445   }
446   void operator++() { ++d_; }
447   TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
448 
449  private:
450   const Shape* shape_;
451   int d_;
452 };
453 // END_SKIP_DOXYGEN
454 
455 /// \brief Static helper routines for `TensorShape`. Includes a few common
456 /// predicates on a tensor shape.
457 class TensorShapeUtils {
458  public:
IsScalar(const TensorShape & shape)459   static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
460 
IsVector(const TensorShape & shape)461   static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
462 
IsVectorOrHigher(const TensorShape & shape)463   static bool IsVectorOrHigher(const TensorShape& shape) {
464     return shape.dims() >= 1;
465   }
466 
IsMatrix(const TensorShape & shape)467   static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
468 
IsSquareMatrix(const TensorShape & shape)469   static bool IsSquareMatrix(const TensorShape& shape) {
470     return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
471   }
472 
IsMatrixOrHigher(const TensorShape & shape)473   static bool IsMatrixOrHigher(const TensorShape& shape) {
474     return shape.dims() >= 2;
475   }
476 
477   /// \brief Returns a `TensorShape` whose dimensions are
478   /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
479   static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
480   static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
481   static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
482   static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
483   static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
484   static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
485   static Status MakeShape(gtl::ArraySlice<int32> shape,
486                           PartialTensorShape* out);
487   static Status MakeShape(gtl::ArraySlice<int64> shape,
488                           PartialTensorShape* out);
489 
490   static std::string ShapeListString(
491       const gtl::ArraySlice<TensorShape>& shapes);
492 
493   /// \brief Returns true iff `shape` starts with `prefix`.
494   static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);
495 
496   /// \brief Returns true iff `shape` ends with `suffix`.
497   static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);
498 
499   /// \brief Returns the product of values in an int64 array,
500   /// or a failing Status if the array represents a value larger than
501   /// a `TensorShape` can hold.
502   static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
503 };
504 
505 /// Manages the partially known dimensions of a Tensor and their sizes.
506 class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
507  public:
PartialTensorShape()508   PartialTensorShape() {}
509   using TensorShapeBase<PartialTensorShape>::TensorShapeBase;
510 
511   /// Add a dimension to the end ("inner-most"), returns a new
512   /// PartialTensorShape.
513   /// REQUIRES: `size >= -1`, where -1 means unknown
514   PartialTensorShape Concatenate(int64 size) const;
515 
516   /// Similar to `Concatenate` but returning `Status`.
517   /// Use if calling code cannot validate all requirements and if `CHECK`-fails
518   /// are to be avoided.
519   Status ConcatenateWithStatus(int64 size, PartialTensorShape* out) const;
520 
521   /// Appends all the dimensions from `shape`.  Returns a new
522   /// PartialTensorShape.
523   PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
524 
525   /// Similar to `Concatenate` but returning `Status`.
526   /// Use if calling code cannot validate all requirements and if `CHECK`-fails
527   /// are to be avoided.
528   Status ConcatenateWithStatus(const PartialTensorShape& shape,
529                                PartialTensorShape* out) const;
530 
531   /// Merges all the dimensions from `shape`.  Returns
532   /// `InvalidArgument` error if either `shape` has a different rank
533   /// or if any of the dimensions are incompatible.
534   Status MergeWith(const PartialTensorShape& shape,
535                    PartialTensorShape* result) const;
536 
537   /// Exact equality test. Returns true iff the ranks match (i.e., both are
538   /// unknown, or both are known and equal), and all dimensions are equal (i.e.,
539   /// both dimensions are known, or both are known and equal). This is a
540   /// stronger condition that IsCompatibleWith.
541   bool IsIdenticalTo(const PartialTensorShape& shape) const;
542 
543   /// Return true iff the ranks match, and if the
544   /// dimensions all either match or one is unknown.
545   bool IsCompatibleWith(const PartialTensorShape& shape) const;
546 
547   // Fill `*shape` from `*this`.
548   // If `*this` is not fully defined, returns false and
549   // `*shape` is left in an intermediate state.  Otherwise
550   // returns true.
551   bool AsTensorShape(TensorShape* shape) const;
552 
553   /// \brief Returns a `PartialTensorShape` whose dimensions are
554   /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
555   /// considered "unknown".
556   template <class T>
MakePartialShape(const T * dims,int n,PartialTensorShape * out)557   static Status MakePartialShape(const T* dims, int n,
558                                  PartialTensorShape* out) {
559     return TensorShapeUtils::MakeShape(dims, n, out);
560   }
561 };
562 
563 /// \brief Static helper routines for `PartialTensorShape`. Includes a few
564 /// common predicates on a partially known tensor shape.
565 class PartialTensorShapeUtils {
566  public:
567   static std::string PartialShapeListString(
568       const gtl::ArraySlice<PartialTensorShape>& shapes);
569 
570   static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
571                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
572 
573   static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
574                             const gtl::ArraySlice<PartialTensorShape>& shapes1);
575 };
576 
577 // ----------------------------------------------------------------------------
578 // Template method implementation details below
579 // ----------------------------------------------------------------------------
580 
581 template <int NDIMS, typename IndexType>
AsEigenDSizesCopy()582 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
583   Eigen::DSizes<IndexType, NDIMS> dsizes;
584   for (int d = 0; d < NDIMS; d++) {
585     dsizes[d] = static_cast<IndexType>(dim_size(d));
586   }
587   return dsizes;
588 }
589 
590 template <int NDIMS, typename IndexType>
AsEigenDSizesCopyAndPad()591 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
592   static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
593   Eigen::DSizes<IndexType, NDIMS> dsizes;
594   for (int d = 0; d < dims(); d++) {
595     dsizes[d] = static_cast<IndexType>(dim_size(d));
596   }
597   for (int d = dims(); d < NDIMS; d++) {
598     dsizes[d] = 1;
599   }
600   return dsizes;
601 }
602 
603 template <int NDIMS, typename IndexType>
AsEigenDSizes()604 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
605   CheckDimsEqual(NDIMS);
606   return AsEigenDSizesCopy<NDIMS, IndexType>();
607 }
608 
609 template <int NDIMS, typename IndexType>
AsEigenDSizesWithStatus(Eigen::DSizes<IndexType,NDIMS> * out)610 Status TensorShape::AsEigenDSizesWithStatus(
611     Eigen::DSizes<IndexType, NDIMS>* out) const {
612   if (TF_PREDICT_FALSE(NDIMS != dims())) {
613     return errors::Internal("Asking for tensor of ", NDIMS,
614                             " dimensions from a tensor of ", dims(),
615                             " dimensions");
616   }
617   *out = AsEigenDSizesCopy<NDIMS, IndexType>();
618 }
619 
620 template <int NDIMS, typename IndexType>
AsEigenDSizesWithPadding()621 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
622   CheckDimsAtLeast(NDIMS);
623   return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
624 }
625 
626 template <int NDIMS, typename IndexType>
AsEigenDSizesWithPaddingWithStatus(Eigen::DSizes<IndexType,NDIMS> * out)627 Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
628     Eigen::DSizes<IndexType, NDIMS>* out) const {
629   if (TF_PREDICT_FALSE(NDIMS < dims())) {
630     return errors::Internal("Asking for tensor of at least ", NDIMS,
631                             " dimensions from a tensor of ", dims(),
632                             " dimensions");
633   }
634   *out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
635 }
636 
637 // ----------------------------------------------------------------------------
638 // Inlining of some performance critical routines
639 // ----------------------------------------------------------------------------
640 
inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() == REP_OUT_OF_LINE) {
    // Our buffer is still uninitialized. Tag it as inline first so that
    // SlowCopyFrom does not try to deallocate a garbage pointer.
    set_tag(REP16);
    SlowCopyFrom(b);
    return;
  }
  // Inline encodings live entirely in the 16-byte buffer, so a raw byte copy
  // also carries over the rank byte and the tag.
  memcpy(buf(), b.buf(), sizeof(u_.buf));
}
653 
TensorShapeRep(TensorShapeRep && b)654 inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
655   num_elements_ = b.num_elements_;
656   memcpy(buf(), b.buf(), sizeof(u_.buf));
657   // memcpy above Implicitly does:
658   //   set_ndims_byte(b.ndims_byte());
659   //   set_tag(b.tag());
660   b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
661 }
662 
~TensorShapeRep()663 inline TensorShapeRep::~TensorShapeRep() {
664   if (tag() == REP_OUT_OF_LINE) {
665     DestructorOutOfLine();
666   }
667 }
668 
669 inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
670   num_elements_ = b.num_elements_;
671   if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
672     memcpy(buf(), b.buf(), sizeof(u_.buf));
673     // memcpy above implicitly also does:
674     //   set_tag(b.tag());
675     //   set_ndims_byte(b.ndims_byte());
676   } else {
677     SlowCopyFrom(b);
678   }
679 }
680 
681 inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
682   if (tag() == REP_OUT_OF_LINE) {
683     DestructorOutOfLine();
684   }
685   num_elements_ = b.num_elements_;
686   memcpy(buf(), b.buf(), sizeof(u_.buf));
687   // memcpy above Implicitly does:
688   //   set_ndims_byte(b.ndims_byte());
689   //   set_tag(b.tag());
690   b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
691 }
692 
693 inline TensorShape::operator const PartialTensorShape&() const {
694   // Downcast to the shared representation and upcast to PartialTensorShape
695   const TensorShapeRep* rep = this;
696   return *static_cast<const PartialTensorShape*>(rep);
697 }
698 
699 template <class Shape>
TensorShapeBase(DataType dt)700 inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
701   set_tag(REP16);
702   set_data_type(dt);
703 
704   // Optimized implementation of InitDims() where the shape is statically known
705   // to be {0}.
706   set_ndims_byte(1);
707   uint16* dst = as16()->dims_;
708   *dst = 0;
709   set_num_elements(0);
710 }
711 
712 // Declare explicit instantiations in .cc file
713 extern template class TensorShapeBase<TensorShape>;
714 extern template class TensorShapeBase<PartialTensorShape>;
715 
716 // A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
717 // often used in shape inference.
718 struct DtypeAndPartialTensorShape {
719   DataType dtype;
720   PartialTensorShape shape;
721 };
722 
723 }  // namespace tensorflow
724 
725 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
726