/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_

#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  string DebugString() const;
  static string DebugString(const TensorShapeProto& proto);

  void DumpRep() const;  // XXX

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape.  Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();
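
  // Illustrative sketch of the limits above: a shape such as {640, 480, 3}
  // presumably fits in Rep16 (each dimension < 2^16 - 1), a shape such as
  // {1 << 20, 1024} needs Rep32, and a shape that fits neither (too many
  // dimensions, or dimensions too large) falls back to the out-of-line Rep64.
  // The actual representation selection lives in the .cc implementation.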

  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..13] vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static const uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64 num_elements_;
};

/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
template <class Shape>
class TensorShapeBase : public TensorShapeRep {
 public:
  /// \brief Construct a `TensorShapeBase` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
  explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
  TensorShapeBase(std::initializer_list<int64> dim_sizes)
      : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}

  /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
  TensorShapeBase();

  TensorShapeBase(const TensorShapeProto& proto);

  /// Returns `true` iff `proto` is a valid tensor shape.
  // For TensorShape, the proto shape must be fully defined.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// Returns `true` iff this is a valid tensor shape.
  bool IsValid();

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64 size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeBase& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64 size);

  /// \brief Modifies the size of the dimension `d` to be `size`
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64 size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d) {
    CHECK_GE(d, 0);
    RemoveDimRange(d, d + 1);
  }

  /// \brief Removes last `n` dimensions from the `TensorShape`.
  /// REQUIRES: `0 <= n <= dims()`
  void RemoveLastDims(int n) {
    CHECK_LE(n, dims());
    RemoveDimRange(dims() - n, dims());
  }

  /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
  /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
  /// Python). The same is true for negative values of `begin`.
  /// REQUIRES: `-(dims()+1) <= begin <= dims()`
  /// REQUIRES: `-(dims()+1) <= end <= dims()`
  void RemoveDimRange(int begin, int end);
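
  // Illustrative sketch: with dims() == 4, negative indices resolve via
  // dims() + index + 1, so RemoveDimRange(-2, -1) is equivalent to
  // RemoveDimRange(3, 4) and removes only the last dimension.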

  /// Return whether the rank is unknown
  bool unknown_rank() const {
    return kIsPartial && ndims_byte() == kUnknownRank;
  }

  /// Return the number of dimensions in the tensor.
  /// Can be -1 meaning unknown rank for PartialTensorShape.
  int dims() const {
    uint8 dims = ndims_byte();
    return kIsPartial && dims == kUnknownRank ? -1 : dims;
  }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64 dim_size(int d) const;

  /// Returns sizes of all dimensions.
  // Returns an empty list for unknown rank PartialTensorShape.
  gtl::InlinedVector<int64, 4> dim_sizes() const;

  /// Return true iff the rank and all of the dimensions are well defined
  // TODO(irving): Rename to is_fully_defined now that it's fast.
  bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// For iterating through the dimensions.
  TensorShapeIter<Shape> begin() const;
  TensorShapeIter<Shape> end() const;

 protected:
  // Optimized constructor for a shape representing an empty vector.
  //
  // This constructor is provided to optimize the default constructor for
  // `Tensor`.
  explicit TensorShapeBase(DataType dt);

 private:
  void RecomputeNumElements();
  void InitDims(gtl::ArraySlice<int64> dim_sizes);

  // True for PartialTensorShape, false for TensorShape
  static constexpr bool kIsPartial =
      std::is_same<Shape, PartialTensorShape>::value;
  static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
                "Shape is neither TensorShape nor PartialTensorShape");

  // Used by AddDim and MakeShapeHelper.  Does no error checking.
  void UnsafeAddDim(int64 size, int64 new_num_elements);

  // For use by TensorShapeUtils::MakeShape
  template <class T, class S>
  friend Status MakeShapeHelper(const T*, int64, S*);
};

/// Outputs `TensorShapeBase` to `std::ostream`.
template <typename Shape>
std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
  return os << tsb.DebugString();
}
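
// Illustrative sketch: streaming a shape forwards to DebugString(), so shapes
// can be logged directly.
//
//   TensorShape shape({2, 3});
//   LOG(INFO) << "shape: " << shape;  // logs the DebugString() form of [2,3]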

/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension.  For example, a Tensor represented by a 3 x 4 matrix would have
/// a shape of 2-D, [3,4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
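///
/// Example (illustrative sketch):
///
///     TensorShape shape({3, 4});       // 2-D shape [3,4]
///     shape.AddDim(5);                 // shape is now [3,4,5]
///     shape.set_dim(0, 6);             // shape is now [6,4,5]
///     int64 n = shape.num_elements();  // 6 * 4 * 5 = 120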
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  /// Allow a TensorShape to be used as a PartialTensorShape without copying
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
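
  // Illustrative sketch: for a 2-D shape [3,4], AsEigenDSizes<2>() yields an
  // Eigen::DSizes holding {3, 4}, while AsEigenDSizesWithPadding<4>() pads the
  // trailing sizes with 1 to yield {3, 4, 1, 1} (see the template
  // implementations near the end of this header).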

 private:
  // These CHECK fail to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() >= NDIMS
  void CheckDimsAtLeast(int NDIMS) const;

  // For access to TensorShapeBase(DataType).
  friend class Tensor;
};

/// Represents the value of one dimension in a TensorShape.
struct TensorShapeDim {
  explicit TensorShapeDim(int64 s) : size(s) {}
  int64 size;
};

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter {
 public:
  TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }

 private:
  const Shape* shape_;
  int d_;
};
// END_SKIP_DOXYGEN
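
// Illustrative sketch: TensorShapeBase::begin()/end() return TensorShapeIter,
// so the dimensions of a shape can be visited with a range-based for loop.
//
//   TensorShape shape({2, 3, 5});
//   for (const TensorShapeDim& d : shape) {
//     // d.size takes the values 2, 3, and 5 in turn.
//   }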

/// \brief Static helper routines for `TensorShape`. Includes a few common
/// predicates on a tensor shape.
class TensorShapeUtils {
 public:
  static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }

  static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }

  static bool IsVectorOrHigher(const TensorShape& shape) {
    return shape.dims() >= 1;
  }

  static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }

  static bool IsSquareMatrix(const TensorShape& shape) {
    return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
  }

  static bool IsMatrixOrHigher(const TensorShape& shape) {
    return shape.dims() >= 2;
  }

  /// \brief Returns a `TensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
  static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
  static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape,
                          PartialTensorShape* out);

  static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);

  /// \brief Returns true iff `shape` starts with `prefix`.
  static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);

  /// \brief Returns true iff `shape` ends with `suffix`.
  static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);

  /// \brief Returns the product of values in an int64 array,
  /// or a failing Status if the array represents a value larger than
  /// a `TensorShape` can hold.
  static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
};
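
// Illustrative sketch: building a shape from a raw dimension array and then
// testing one of the predicates above.
//
//   int64 dims[] = {2, 2};
//   TensorShape shape;
//   Status s = TensorShapeUtils::MakeShape(dims, 2, &shape);
//   if (s.ok() && TensorShapeUtils::IsSquareMatrix(shape)) {
//     // shape is a 2 x 2 matrix.
//   }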

/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64 size) const;

  /// Appends all the dimensions from `shape`.  Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Merges all the dimensions from `shape`.  Returns
  /// `InvalidArgument` error if either `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal) and all dimensions are equal (i.e.,
  /// for each dimension, both are unknown, or both are known and equal). This
  /// is a stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state.  Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};
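
// Illustrative sketch: combining partially known shapes, where a dimension of
// -1 means unknown.
//
//   PartialTensorShape a({-1, 28, 28});
//   PartialTensorShape b({32, -1, 28});
//   bool compat = a.IsCompatibleWith(b);  // true: ranks match, dims compatible
//   PartialTensorShape merged;
//   Status s = a.MergeWith(b, &merged);   // merged is [32,28,28] on success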

/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  static string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};

// ----------------------------------------------------------------------------
// Template method implementation details below
// ----------------------------------------------------------------------------

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesWithPadding<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
  CheckDimsAtLeast(NDIMS);
  static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < dims(); d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  for (int d = dims(); d < NDIMS; d++) {
    dsizes[d] = 1;
  }
  return dsizes;
}

// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------

inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_ndims_byte(b.ndims_byte());
    //   set_tag(b.tag());
  } else {
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    SlowCopyFrom(b);
  }
}

inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly also does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShapeRep::~TensorShapeRep() {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}

inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
  } else {
    SlowCopyFrom(b);
  }
}

inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly also does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}

template <class Shape>
inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
  set_tag(REP16);
  set_data_type(dt);

  // Optimized implementation of InitDims() where the shape is statically known
  // to be {0}.
  set_ndims_byte(1);
  uint16* dst = as16()->dims_;
  *dst = 0;
  set_num_elements(0);
}

// Declare explicit instantiations in .cc file
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;

// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
// often used in shape inference.
struct DtypeAndPartialTensorShape {
  DataType dtype;
  PartialTensorShape shape;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_