/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_

#include <limits>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape.
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape. After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or a number of elements that
  /// was not properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`,
  /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means the number of
  /// elements is not fully defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  string DebugString() const;
  static string DebugString(const TensorShapeProto& proto);

  void DumpRep() const;  // XXX

 protected:
  // Constructible only via TensorShapeBase.
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape. Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensional, we have several representations:
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent an unknown
  // dimension, so the maximum representable valid dimension size in these
  // representations is one less.
  static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();
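
  // As an illustrative sketch of how these fit together: a shape like
  // [2, 3, 4] fits in Rep16 (at most 6 dimensions, each <= kMaxRep16), a
  // shape such as [1 << 20, 1 << 20] needs Rep32, and a shape with more than
  // 6 dimensions or any dimension larger than kMaxRep32 falls back to the
  // out-of-line Rep64. The selection logic itself lives in the .cc file.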

  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage. This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store the DataType, so make sure it
    // fits.
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..12] vary depending on the representation (byte 13 holds the
  // data type, see above).
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static const uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64 num_elements_;
};

/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
template <class Shape>
class TensorShapeBase : public TensorShapeRep {
 public:
  /// \brief Construct a `TensorShapeBase` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
  explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
  TensorShapeBase(std::initializer_list<int64> dim_sizes)
      : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}

  /// Construct an empty TensorShape, or an unknown rank PartialTensorShape.
  TensorShapeBase();

  TensorShapeBase(const TensorShapeProto& proto);

  /// Returns `true` iff `proto` is a valid tensor shape.
  // For TensorShape, the proto shape must be fully defined.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64 size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeBase& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64 size);

  /// \brief Modifies the size of the dimension `d` to be `size`.
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64 size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d) {
    CHECK_GE(d, 0);
    RemoveDimRange(d, d + 1);
  }

  /// \brief Removes the last `n` dimensions from the `TensorShape`.
  /// REQUIRES: `0 <= n <= dims()`
  void RemoveLastDims(int n) {
    CHECK_LE(n, dims());
    RemoveDimRange(dims() - n, dims());
  }

  /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
  /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
  /// Python). The same is true for negative values of `begin`.
  /// REQUIRES: `-(dims()+1) <= begin <= dims()`
  /// REQUIRES: `-(dims()+1) <= end <= dims()`
  void RemoveDimRange(int begin, int end);
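
  // A minimal sketch of how the mutation methods above compose (illustrative,
  // written in terms of the TensorShape subclass):
  //
  //   TensorShape s({2, 3, 4});    // [2, 3, 4]
  //   s.AddDim(5);                 // [2, 3, 4, 5]
  //   s.InsertDim(0, 7);           // [7, 2, 3, 4, 5]
  //   s.set_dim(1, 9);             // [7, 9, 3, 4, 5]
  //   s.RemoveDim(2);              // [7, 9, 4, 5]
  //   s.RemoveDimRange(1, 3);      // [7, 5]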

  /// Return whether the rank is unknown
  bool unknown_rank() const {
    return kIsPartial && ndims_byte() == kUnknownRank;
  }

  /// Return the number of dimensions in the tensor.
  /// Can be -1 meaning unknown rank for PartialTensorShape.
  int dims() const {
    uint8 dims = ndims_byte();
    return kIsPartial && dims == kUnknownRank ? -1 : dims;
  }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64 dim_size(int d) const;

  /// Returns sizes of all dimensions.
  // Returns an empty list for unknown rank PartialTensorShape.
  gtl::InlinedVector<int64, 4> dim_sizes() const;

  /// Return true iff the rank and all of the dimensions are well defined
  // TODO(irving): Rename to is_fully_defined now that it's fast.
  bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// For iterating through the dimensions.
  TensorShapeIter<Shape> begin() const;
  TensorShapeIter<Shape> end() const;

 private:
  void RecomputeNumElements();
  void InitDims(gtl::ArraySlice<int64> dim_sizes);

  // True for PartialTensorShape, false for TensorShape
  static constexpr bool kIsPartial =
      std::is_same<Shape, PartialTensorShape>::value;
  static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
                "Shape is neither TensorShape nor PartialTensorShape");

  // Used by AddDim and MakeShapeHelper. Does no error checking.
  void UnsafeAddDim(int64 size, int64 new_num_elements);

  // For use by TensorShapeUtils::MakeShape
  template <class T, class S>
  friend Status MakeShapeHelper(const T*, int64, S*);
};

/// Outputs `TensorShapeBase` to `std::ostream`.
template <typename Shape>
std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
  return os << tsb.DebugString();
}

/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have
/// a shape of 2-D, [3,4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
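///
/// A minimal usage sketch:
///
///     TensorShape shape({3, 4});       // a 2-D shape, [3, 4]
///     shape.AddDim(5);                 // now [3, 4, 5]
///     int64 n = shape.num_elements();  // 60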
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  /// Allow a TensorShape to be used as a PartialTensorShape without copying
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
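
  // Illustrative sketch of the two conversions above (assumes a rank-2 shape):
  //
  //   TensorShape shape({3, 4});
  //   auto exact  = shape.AsEigenDSizes<2>();             // {3, 4}
  //   auto padded = shape.AsEigenDSizesWithPadding<4>();  // {3, 4, 1, 1}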

 private:
  // These methods CHECK-fail to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() >= NDIMS
  void CheckDimsAtLeast(int NDIMS) const;
};

/// Represents the value of one dimension in a TensorShape.
struct TensorShapeDim {
  explicit TensorShapeDim(int64 s) : size(s) {}
  int64 size;
};

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter {
 public:
  TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }

 private:
  const Shape* shape_;
  int d_;
};
// END_SKIP_DOXYGEN

/// \brief Static helper routines for `TensorShape`. Includes a few common
/// predicates on a tensor shape.
class TensorShapeUtils {
 public:
  static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }

  static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }

  static bool IsVectorOrHigher(const TensorShape& shape) {
    return shape.dims() >= 1;
  }

  static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }

  static bool IsSquareMatrix(const TensorShape& shape) {
    return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
  }

  static bool IsMatrixOrHigher(const TensorShape& shape) {
    return shape.dims() >= 2;
  }

  /// \brief Returns a `TensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
  static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
  static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape,
                          PartialTensorShape* out);
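
  // Example of the MakeShape overloads above (illustrative):
  //
  //   int64 dims[] = {2, 3, 4};
  //   TensorShape shape;
  //   Status s = TensorShapeUtils::MakeShape(dims, 3, &shape);
  //   // On success, shape is [2, 3, 4].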

  static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);

  /// \brief Returns true iff `shape` starts with `prefix`.
  static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);

  /// \brief Returns true iff `shape` ends with `suffix`.
  static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);
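
  // Illustrative: StartsWith([2, 3, 4], [2, 3]) and EndsWith([2, 3, 4], [3, 4])
  // both return true.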

  /// \brief Returns the product of values in an int64 array,
  /// or a failing Status if the array represents a value larger than
  /// a `TensorShape` can hold.
  static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
};

/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most") and return a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64 size) const;

  /// Appends all the dimensions from `shape`. Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
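
  // Illustrative sketch: with s == [16, ?], i.e. PartialTensorShape({16, -1}),
  // s.Concatenate(3) yields [16, ?, 3] and s.Concatenate(s) yields
  // [16, ?, 16, ?].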

  /// Merges all the dimensions from `shape`. Returns an
  /// `InvalidArgument` error if `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal
  /// (i.e., both dimensions are unknown, or both are known and equal). This is
  /// a stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;
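
  // Illustrative sketch of compatibility vs. merging:
  //
  //   PartialTensorShape a({-1, 3});
  //   PartialTensorShape b({2, -1});
  //   a.IsCompatibleWith(b);               // true: ranks match, dims merge
  //   a.IsIdenticalTo(b);                  // false: the dimensions differ
  //   PartialTensorShape merged;
  //   Status s = a.MergeWith(b, &merged);  // on success, merged is [2, 3]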

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state. Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};

/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  static string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};

// ----------------------------------------------------------------------------
// Template method implementation details below
// ----------------------------------------------------------------------------

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesWithPadding<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
  CheckDimsAtLeast(NDIMS);
  static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < dims(); d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  for (int d = dims(); d < NDIMS; d++) {
    dsizes[d] = 1;
  }
  return dsizes;
}

// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------

inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_ndims_byte(b.ndims_byte());
    //   set_tag(b.tag());
  } else {
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    SlowCopyFrom(b);
  }
}

inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly also does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShapeRep::~TensorShapeRep() {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}

inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
  } else {
    SlowCopyFrom(b);
  }
}

inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly also does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape.
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}

// Declare explicit instantiations that are defined in the .cc file.
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_