/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_

#include <cstring>
#include <limits>
#include <ostream>
#include <string>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  std::string DebugString() const;
  static std::string DebugString(const TensorShapeProto& proto);

 protected:
  // Constructible only via TensorShapeBase.
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape.  Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;
  };
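
  // Illustration (not part of the API): which representation a shape ends up
  // using, per the size limits documented above.
  //   TensorShape({2, 3, 4})           -> Rep16 (up to 6 dims, each < 2^16 - 1)
  //   TensorShape({1 << 20, 1 << 20})  -> Rep32 (a dim needs more than 16 bits)
  //   TensorShape({1LL << 40})         -> Rep64 (a dim needs a full 64-bit size)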

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one
  // less.
  static constexpr int64_t kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static constexpr int64_t kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..13] vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static constexpr uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }
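
  // Layout sketch of the 16-byte buffer u_.buf (descriptive only, derived from
  // the accessors above):
  //   bytes [0..11]  dimension sizes for Rep16/Rep32, or a Rep64* in [0..7]
  //   byte  13       DataType (see data_type()/set_data_type())
  //   byte  14       number of dimensions, 255 meaning unknown rank
  //   byte  15       RepTag selecting Rep16, Rep32, or out-of-line storage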

  void set_num_elements(int64_t n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64 num_elements_;
};

/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
template <class Shape>
class TensorShapeBase : public TensorShapeRep {
 public:
  /// \brief Construct a `TensorShapeBase` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
  explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
  TensorShapeBase(std::initializer_list<int64> dim_sizes)
      : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}

  /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
  TensorShapeBase();

  // TODO(mihaimaruseac): Mark this explicit in a subsequent change
  TensorShapeBase(const TensorShapeProto& proto);

  // These factory methods should be used instead of the constructors that take
  // an array of sizes if calling code cannot validate that the sizes specify a
  // valid `TensorShape`.
  // The value in `*out` is valid iff the returned value is `Status::OK`.
  static Status BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,
                                     TensorShapeBase* out);
  static Status BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,
                                     TensorShapeBase* out) {
    return BuildTensorShapeBase(gtl::ArraySlice<int64>(dim_sizes), out);
  }
  static Status BuildTensorShapeBase(const TensorShapeProto& proto,
                                     TensorShapeBase* out);
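
  // A minimal usage sketch (illustrative only; `dim0`/`dim1` are placeholder
  // sizes): validate sizes that come from untrusted input instead of
  // CHECK-failing in a constructor.
  //
  //   TensorShape shape;
  //   Status s = TensorShape::BuildTensorShapeBase({dim0, dim1}, &shape);
  //   if (!s.ok()) { /* reject the input */ }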

  /// Returns `true` iff `proto` is a valid tensor shape.
  // For TensorShape, the proto shape must be fully defined.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// Returns `true` iff this is a valid tensor shape.
  bool IsValid();

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64_t size);

  /// Same as `AddDim` but returns a `Status`.
  /// Use if unsure whether `size >= 0`, to prevent `CHECK`-crashes.
  Status AddDimWithStatus(int64_t size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeBase& shape);

  /// Same as `AppendShape` but returns a `Status`.
  /// Use if you cannot validate all invariants, to prevent `CHECK`-fail.
  Status AppendShapeWithStatus(const TensorShapeBase& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64_t size);

  /// Same as `InsertDim` but returns a `Status`.
  /// Use if unsure whether the requirements in `InsertDim` are satisfied, to
  /// prevent `CHECK`-fail crashes.
  Status InsertDimWithStatus(int d, int64_t size);

  /// \brief Modifies the size of the dimension `d` to be `size`
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64_t size);

  /// Same as `set_dim` but returns a `Status`.
  /// Use if unsure whether the requirements in `set_dim` are satisfied, to
  /// prevent `CHECK`-fail crashes.
  Status SetDimWithStatus(int d, int64_t size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d) {
    CHECK_GE(d, 0);
    RemoveDimRange(d, d + 1);
  }

  /// Same as `RemoveDim` but returns a `Status`.
  /// Use if unsure whether `0 <= d < dims()`, to prevent `CHECK`-crashes.
  Status RemoveDimWithStatus(int64_t d) {
    if (TF_PREDICT_FALSE(d < 0)) {
      return errors::Internal(
          "Expected dimension index to be non-negative, got ", d);
    }
    return RemoveDimRangeWithStatus(d, d + 1);
  }

  /// \brief Removes last `n` dimensions from the `TensorShape`.
  /// REQUIRES: `0 <= n <= dims()`
  void RemoveLastDims(int n) {
    CHECK_LE(n, dims());
    RemoveDimRange(dims() - n, dims());
  }

  /// Same as `RemoveLastDims` but returns a `Status`.
  /// Use if unsure whether `0 <= n <= dims()`, to prevent `CHECK`-crashes.
  Status RemoveLastDimsWithStatus(int64_t n) {
    if (TF_PREDICT_FALSE(n > dims())) {
      return errors::Internal("Expected dimension index to be at most ", dims(),
                              " got ", n);
    }
    return RemoveDimRangeWithStatus(dims() - n, dims());
  }

  /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
  /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
  /// Python). The same is true for negative values of `begin`.
  /// REQUIRES: `-(dims()+1) <= begin <= dims()`
  /// REQUIRES: `-(dims()+1) <= end <= dims()`
  void RemoveDimRange(int begin, int end);
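
  // Illustration (comment only) of the negative-index convention above, for a
  // shape [2, 3, 5, 7] (dims() == 4):
  //   RemoveDimRange(1, 3)    removes dims 1 and 2, leaving [2, 7]
  //   RemoveDimRange(-2, -1)  is RemoveDimRange(3, 4), leaving [2, 3, 5]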

  /// Same as `RemoveDimRange` but returns a `Status`.
  /// Use if unsure whether the requirements in `RemoveDimRange` are satisfied,
  /// to prevent `CHECK`-fail crashes.
  Status RemoveDimRangeWithStatus(int begin, int end);

  /// Return whether the rank is unknown
  bool unknown_rank() const {
    return kIsPartial && ndims_byte() == kUnknownRank;
  }

  /// Return the number of dimensions in the tensor.
  /// Can be -1 meaning unknown rank for PartialTensorShape.
  int dims() const {
    uint8 dims = ndims_byte();
    return kIsPartial && dims == kUnknownRank ? -1 : dims;
  }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64 dim_size(int d) const;

  /// Returns sizes of all dimensions.
  // Returns an empty list for unknown rank PartialTensorShape.
  gtl::InlinedVector<int64, 4> dim_sizes() const;

  /// Return true iff the rank and all of the dimensions are well defined
  // TODO(irving): Rename to is_fully_defined now that it's fast.
  bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// For iterating through the dimensions.
  TensorShapeIter<Shape> begin() const;
  TensorShapeIter<Shape> end() const;

 protected:
  // Optimized constructor for a shape representing an empty vector.
  //
  // This constructor is provided to optimize the default constructor for
  // `Tensor`.
  explicit TensorShapeBase(DataType dt);

 private:
  Status RecomputeNumElements();
  Status InitDims(gtl::ArraySlice<int64> dim_sizes);

  // True for PartialTensorShape, false for TensorShape
  static constexpr bool kIsPartial =
      std::is_same<Shape, PartialTensorShape>::value;
  static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
                "Shape is neither TensorShape nor PartialTensorShape");

  // Used by AddDim and MakeShapeHelper.  Does no error checking.
  void UnsafeAddDim(int64_t size, int64_t new_num_elements);

  // For use by TensorShapeUtils::MakeShape
  template <class T, class S>
  friend Status MakeShapeHelper(const T*, int64_t, S*);
};

/// Outputs `TensorShapeBase` to `std::ostream`.
template <typename Shape>
std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
  return os << tsb.DebugString();
}

/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension.  For example, a Tensor represented by a 3 x 4 matrix would have
/// a shape of 2-D, [3,4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
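///
/// A minimal sketch of both styles (illustrative only):
///
///     TensorShape shape({3, 4});       // starts as the 2-D shape [3, 4]
///     shape.AddDim(5);                 // now [3, 4, 5]
///     int64 n = shape.num_elements();  // 3 * 4 * 5 = 60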
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  /// Allow a TensorShape to be used as a PartialTensorShape without copying
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;

  // Same as `AsEigenDSizes()` but returns a `Status` instead.
  // Use this method to surface errors to the user instead of crashing if
  // `NDIMS` is not equal to `dims()`.
  // Caller retains ownership of `out`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
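
  // Illustration (comment only): for a shape [3, 4],
  //   AsEigenDSizes<2>()            yields {3, 4}
  //   AsEigenDSizesWithPadding<4>() yields {3, 4, 1, 1}
  //   AsEigenDSizes<3>()            CHECK-fails, since dims() != 3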

  // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
  // Use this method to surface errors to the user instead of crashing if
  // `NDIMS` is smaller than `dims()`.
  // Caller retains ownership of `out`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithPaddingWithStatus(
      Eigen::DSizes<IndexType, NDIMS>* out) const;

 private:
  // These CHECK-fail on violation, to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() >= NDIMS
  void CheckDimsAtLeast(int NDIMS) const;

  // Fill output from `*this`.
  // Helper method for common code between `AsEigenDSizes()` and
  // `AsEigenDSizesWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;

  // Fill output from `*this`.
  // Helper method for common code between `AsEigenDSizesWithPadding()` and
  // `AsEigenDSizesWithPaddingWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;

  // For access to TensorShapeBase(DataType).
  friend class Tensor;
};

/// Outputs `TensorShape` to `std::ostream`.
inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) {
  return os << ts.DebugString();
}

/// Represents the value of one dimension in a TensorShape.
struct TensorShapeDim {
  explicit TensorShapeDim(int64_t s) : size(s) {}
  int64 size;
};

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter {
 public:
  TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }

 private:
  const Shape* shape_;
  int d_;
};
// END_SKIP_DOXYGEN

/// \brief Static helper routines for `TensorShape`. Includes a few common
/// predicates on a tensor shape.
class TensorShapeUtils {
 public:
  static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }

  static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }

  static bool IsVectorOrHigher(const TensorShape& shape) {
    return shape.dims() >= 1;
  }

  static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }

  static bool IsSquareMatrix(const TensorShape& shape) {
    return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
  }

  static bool IsMatrixOrHigher(const TensorShape& shape) {
    return shape.dims() >= 2;
  }

  /// \brief Returns a `TensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
  static Status MakeShape(const int32* dims, int64_t n, TensorShape* out);
  static Status MakeShape(const int64* dims, int64_t n, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
  static Status MakeShape(const int32* dims, int64_t n,
                          PartialTensorShape* out);
  static Status MakeShape(const int64* dims, int64_t n,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape,
                          PartialTensorShape* out);
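
  // Sketch of intended use (illustrative only; `raw_dims` is a placeholder for
  // sizes coming from, e.g., parsed input): build a shape without
  // CHECK-failing on invalid sizes.
  //
  //   int64 raw_dims[] = {2, 3};
  //   TensorShape shape;
  //   Status s = TensorShapeUtils::MakeShape(raw_dims, 2, &shape);
  //   if (!s.ok()) { /* handle the error */ }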

  static std::string ShapeListString(
      const gtl::ArraySlice<TensorShape>& shapes);

  /// \brief Returns true iff `shape` starts with `prefix`.
  static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);

  /// \brief Returns true iff `shape` ends with `suffix`.
  static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);

  /// \brief Returns the product of values in an int64 array,
  /// or a failing Status if the array represents a value larger than
  /// a `TensorShape` can hold.
  static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
};

/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64_t size) const;

  /// Similar to `Concatenate` but returning `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(int64_t size, PartialTensorShape* out) const;

  /// Appends all the dimensions from `shape`.  Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Similar to `Concatenate` but returning `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(const PartialTensorShape& shape,
                               PartialTensorShape* out) const;

  /// Merges all the dimensions from `shape`.  Returns an
  /// `InvalidArgument` error if `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal
  /// (i.e., both are unknown, or both are known and equal). This is a
  /// stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;
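
  // Illustration (comment only), writing -1 for an unknown dimension:
  //   [2, -1] is compatible with [2, 3]; MergeWith yields [2, 3].
  //   [2, -1] is identical to [2, -1] but not to [2, 3].
  //   [2, -1] is not compatible with [3, 3]: the first dimensions differ.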

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state.  Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};

/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  static std::string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};

// ----------------------------------------------------------------------------
// Template method implementation details below
// ----------------------------------------------------------------------------

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < NDIMS; d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  return dsizes;
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
  static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < dims(); d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  for (int d = dims(); d < NDIMS; d++) {
    dsizes[d] = 1;
  }
  return dsizes;
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesCopy<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithStatus(
    Eigen::DSizes<IndexType, NDIMS>* out) const {
  if (TF_PREDICT_FALSE(NDIMS != dims())) {
    return errors::Internal("Asking for tensor of ", NDIMS,
                            " dimensions from a tensor of ", dims(),
                            " dimensions");
  }
  *out = AsEigenDSizesCopy<NDIMS, IndexType>();
  return Status::OK();
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
  CheckDimsAtLeast(NDIMS);
  return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
    Eigen::DSizes<IndexType, NDIMS>* out) const {
  if (TF_PREDICT_FALSE(NDIMS < dims())) {
    return errors::Internal("Asking for tensor of at least ", NDIMS,
                            " dimensions from a tensor of ", dims(),
                            " dimensions");
  }
  *out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
  return Status::OK();
}

// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------

inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly does:
    //   set_ndims_byte(b.ndims_byte());
    //   set_tag(b.tag());
  } else {
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    SlowCopyFrom(b);
  }
}

inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShapeRep::~TensorShapeRep() {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}

inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
  } else {
    SlowCopyFrom(b);
  }
}

inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}

template <class Shape>
inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
  set_tag(REP16);
  set_data_type(dt);

  // Optimized implementation of InitDims() where the shape is statically known
  // to be {0}.
  set_ndims_byte(1);
  uint16* dst = as16()->dims_;
  *dst = 0;
  set_num_elements(0);
}

// Declare explicit instantiations in .cc file
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;

// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
// often used in shape inference.
struct DtypeAndPartialTensorShape {
  DataType dtype;
  PartialTensorShape shape;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_