1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
18
19 #include <string>
20
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/core/stringpiece.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/platform/logging.h"
30
31 namespace tensorflow {
32
// START_SKIP_DOXYGEN
// Forward declarations of the shape types used throughout this header, so
// they can be named before their definitions appear.
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN
40
/// Internal representation for both TensorShape and PartialTensorShape.
///
/// Stores the dimensions in a packed 16-byte buffer (`u_.buf`) and caches the
/// element count separately in `num_elements_`.  Instances can only be
/// created through a derived class, because the default constructor is
/// protected.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  /// Maximum number of dimensions in a tensor.
  /// It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  string DebugString() const;
  static string DebugString(const TensorShapeProto& proto);

  // Defined out of line.
  // NOTE(review): presumably a debugging aid that dumps the raw
  // representation; the "XXX" marker suggests it was meant to be temporary --
  // confirm before relying on it.
  void DumpRep() const;  // XXX

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape.  Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  //
  // All three structs are overlaid on the start of the same buffer (see
  // as16()/as32()/as64() below); which one is live is recorded by tag().
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  // Reinterpret the start of the buffer as one of the three representations.
  // Callers are responsible for picking the accessor that matches tag().
  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  // Discriminator for the representation currently stored in the buffer.
  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.  The value lives in byte 13 of the buffer.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Byte 13 holds the DataType (see data_type() above); the leading bytes
  // hold the dimension data and vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static const uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  // Overwrites the cached element count; does not touch the dimensions.
  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  // Out-of-line slow paths used by the inline destructor and copy operations
  // when an out-of-line (REP_OUT_OF_LINE) representation is involved.
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  // Raw access to the 16-byte packed buffer.
  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  // Cached element count; see num_elements().
  int64 num_elements_;
};
160
161 /// Base class for TensorShape and PartialTensorShape.
162 /// The class is templatized by either TensorShape or PartialTensorShape to
163 /// allow skipping known/unknown checks in the TensorShape case, but the
164 /// representation is shared exactly for fast conversion.
165 template <class Shape>
166 class TensorShapeBase : public TensorShapeRep {
167 public:
168 /// \brief Construct a `TensorShapeBase` from the provided sizes.
169 /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
170 explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
TensorShapeBase(std::initializer_list<int64> dim_sizes)171 TensorShapeBase(std::initializer_list<int64> dim_sizes)
172 : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}
173
174 /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
175 TensorShapeBase();
176
177 TensorShapeBase(const TensorShapeProto& proto);
178
179 /// Returns `true` iff `proto` is a valid tensor shape.
180 // For TensorShape, the proto shape must be fully defined.
181 static bool IsValid(const TensorShapeProto& proto);
182
183 /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
184 /// status otherwise.
185 static Status IsValidShape(const TensorShapeProto& proto);
186
187 /// Returns `true` iff this is a valid tensor shape.
188 bool IsValid();
189
190 /// \brief Add a dimension to the end ("inner-most").
191 /// REQUIRES: `size >= 0`
192 void AddDim(int64 size);
193
194 /// Appends all the dimensions from `shape`.
195 void AppendShape(const TensorShapeBase& shape);
196
197 /// \brief Insert a dimension somewhere in the `TensorShape`.
198 /// REQUIRES: `0 <= d <= dims()`
199 /// REQUIRES: `size >= 0`
200 void InsertDim(int d, int64 size);
201
202 /// \brief Modifies the size of the dimension `d` to be `size`
203 /// REQUIRES: `0 <= d < dims()`
204 /// REQUIRES: `size >= 0`
205 void set_dim(int d, int64 size);
206
207 /// \brief Removes dimension `d` from the `TensorShape`.
208 /// REQUIRES: `0 <= d < dims()`
RemoveDim(int d)209 void RemoveDim(int d) {
210 CHECK_GE(d, 0);
211 RemoveDimRange(d, d + 1);
212 }
213
214 /// \brief Removes last `n` dimensions from the `TensorShape`.
215 /// REQUIRES: `0 <= n <= dims()`
RemoveLastDims(int n)216 void RemoveLastDims(int n) {
217 CHECK_LE(n, dims());
218 RemoveDimRange(dims() - n, dims());
219 }
220
221 /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
222 /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
223 /// Python). The same is true for negative values of `begin`. REQUIRES:
224 /// `-(dims()+1) <= begin <= dims()` REQUIRES: `-(dims()+1) <= end <= dims()`
225 void RemoveDimRange(int begin, int end);
226
227 /// Return whether the rank is unknown
unknown_rank()228 bool unknown_rank() const {
229 return kIsPartial && ndims_byte() == kUnknownRank;
230 }
231
232 /// Return the number of dimensions in the tensor.
233 /// Can be -1 meaning unknown rank for PartialTensorShape.
dims()234 int dims() const {
235 uint8 dims = ndims_byte();
236 return kIsPartial && dims == kUnknownRank ? -1 : dims;
237 }
238
239 /// \brief Returns the number of elements in dimension `d`.
240 /// REQUIRES: `0 <= d < dims()`
241 // TODO(touts): Rename to `dimension()` to match
242 // `Eigen::Tensor::dimension()`?
243 int64 dim_size(int d) const;
244
245 /// Returns sizes of all dimensions.
246 // Returns an empty list for unknown rank PartialTensorShape.
247 gtl::InlinedVector<int64, 4> dim_sizes() const;
248
249 /// Return true iff the rank and all of the dimensions are well defined
250 // TODO(irving): Rename to is_fully_defined now that it's fast.
IsFullyDefined()251 bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }
252
253 /// Fill `*proto` from `*this`.
254 void AsProto(TensorShapeProto* proto) const;
255
256 /// For iterating through the dimensions.
257 TensorShapeIter<Shape> begin() const;
258 TensorShapeIter<Shape> end() const;
259
260 protected:
261 // Optimized constructor for a shape representing an empty vector.
262 //
263 // This constructor is provided to optimize the default constructor for
264 // `Tensor`.
265 explicit TensorShapeBase(DataType dt);
266
267 private:
268 void RecomputeNumElements();
269 void InitDims(gtl::ArraySlice<int64> dim_sizes);
270
271 // True for PartialTensorShape, false for TensorShape
272 static constexpr bool kIsPartial =
273 std::is_same<Shape, PartialTensorShape>::value;
274 static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
275 "Shape is neither TensorShape nor PartialTensorShape");
276
277 // Used by AddDim and MakeShapeHelper. Does no error checking.
278 void UnsafeAddDim(int64 size, int64 new_num_elements);
279
280 // For use by TensorShapeUtils::MakeShape
281 template <class T, class S>
282 friend Status MakeShapeHelper(const T*, int64, S*);
283 };
284
285 /// Outputs `TensorShapeBase` to `std::ostream`.
286 template <typename Shape>
287 std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
288 return os << tsb.DebugString();
289 }
290
291 /// Represents the shape of a Tensor.
292 ///
293 /// A tensor's shape is denoted by its number of dimensions and a size for each
294 /// dimension. For example, a Tensor represented by a 3 x 4 matrix would have
295 /// a shape of 2-D, [3,4].
296 ///
297 /// If you know the exact shape of your Tensor when you create the TensorShape
298 /// object, you can specify it then, or you can create a TensorShape with
299 /// zero dimensions and one element, and call AddDim() to add dimensions later.
300 class TensorShape : public TensorShapeBase<TensorShape> {
301 public:
302 using TensorShapeBase<TensorShape>::TensorShapeBase;
303
304 /// Allow a TensorShape to be used as a PartialTensorShape without copying
305 operator const PartialTensorShape&() const; // NOLINT(runtime/explicit)
306
307 /// Returns true if `*this` and `b` have the same sizes. Ignores
308 /// dimension names.
309 bool IsSameSize(const TensorShape& b) const;
310 bool operator==(const TensorShape& b) const { return IsSameSize(b); }
311 bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }
312
313 /// Fill `*dsizes` from `*this`.
314 /// Notice: Using IndexType=int32 in combination with To32Bit() can
315 /// significantly improve performance on GPU.
316 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
317 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
318
319 /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
320 /// which case we pad the rest of the sizes with 1.
321 /// Notice: Using IndexType=int32 in combination with To32Bit() can
322 /// significantly improve performance on GPU.
323 template <int NDIMS, typename IndexType = Eigen::DenseIndex>
324 Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
325
326 private:
327 // These CHECK fail to ease debugging.
328 // REQUIRES: dims() == NDIMS
329 void CheckDimsEqual(int NDIMS) const;
330 // REQUIRES: dims() >= NDIMS
331 void CheckDimsAtLeast(int NDIMS) const;
332
333 // For access to TensorShapeBase(DataType).
334 friend class Tensor;
335 };
336
337 /// Represents the value of one dimension in a TensorShape.
338 struct TensorShapeDim {
TensorShapeDimTensorShapeDim339 explicit TensorShapeDim(int64 s) : size(s) {}
340 int64 size;
341 };
342
343 // START_SKIP_DOXYGEN
344 template <class Shape>
345 class TensorShapeIter {
346 public:
TensorShapeIter(const Shape * shape,int d)347 TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
348 bool operator==(const TensorShapeIter& rhs) {
349 DCHECK(shape_ == rhs.shape_);
350 return d_ == rhs.d_;
351 }
352 bool operator!=(const TensorShapeIter& rhs) {
353 DCHECK(shape_ == rhs.shape_);
354 return d_ != rhs.d_;
355 }
356 void operator++() { ++d_; }
357 TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
358
359 private:
360 const Shape* shape_;
361 int d_;
362 };
363 // END_SKIP_DOXYGEN
364
365 /// \brief Static helper routines for `TensorShape`. Includes a few common
366 /// predicates on a tensor shape.
367 class TensorShapeUtils {
368 public:
IsScalar(const TensorShape & shape)369 static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
370
IsVector(const TensorShape & shape)371 static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
372
IsVectorOrHigher(const TensorShape & shape)373 static bool IsVectorOrHigher(const TensorShape& shape) {
374 return shape.dims() >= 1;
375 }
376
IsMatrix(const TensorShape & shape)377 static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
378
IsSquareMatrix(const TensorShape & shape)379 static bool IsSquareMatrix(const TensorShape& shape) {
380 return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
381 }
382
IsMatrixOrHigher(const TensorShape & shape)383 static bool IsMatrixOrHigher(const TensorShape& shape) {
384 return shape.dims() >= 2;
385 }
386
387 /// \brief Returns a `TensorShape` whose dimensions are
388 /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
389 static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
390 static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
391 static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
392 static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
393 static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
394 static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
395 static Status MakeShape(gtl::ArraySlice<int32> shape,
396 PartialTensorShape* out);
397 static Status MakeShape(gtl::ArraySlice<int64> shape,
398 PartialTensorShape* out);
399
400 static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
401
402 /// \brief Returns true iff `shape` starts with `prefix`.
403 static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);
404
405 /// \brief Returns true iff `shape` ends with `suffix`.
406 static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);
407
408 /// \brief Returns the product of values in an int64 array,
409 /// or a failing Status if the array represents a value larger than
410 /// a `TensorShape` can hold.
411 static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
412 };
413
/// Manages the partially known dimensions of a Tensor and their sizes.
///
/// A dimension size of -1 means "unknown"; the rank itself may also be
/// unknown.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64 size) const;

  /// Appends all the dimensions from `shape`.  Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Merges all the dimensions from `shape`.  Returns
  /// `InvalidArgument` error if either `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal (i.e.,
  /// both dimensions are unknown, or both are known and equal). This is a
  /// stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  /// Fill `*shape` from `*this`.
  /// If `*this` is not fully defined, returns false and
  /// `*shape` is left in an intermediate state.  Otherwise
  /// returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
  /// considered "unknown".
  ///
  /// Thin forwarder to `TensorShapeUtils::MakeShape(dims, n, out)`, so `T`
  /// must be an element type for which such an overload exists.
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};
460
/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  /// Formats a list of partial shapes as a string.
  static string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  /// List-wise counterpart of `PartialTensorShape::IsIdenticalTo`.
  // NOTE(review): presumably requires equal list lengths and compares
  // element by element -- confirm against the definition.
  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  /// List-wise counterpart of `PartialTensorShape::IsCompatibleWith`.
  // NOTE(review): presumably requires equal list lengths and compares
  // element by element -- confirm against the definition.
  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};
474
475 // ----------------------------------------------------------------------------
476 // Template method implementation details below
477 // ----------------------------------------------------------------------------
478
479 template <int NDIMS, typename IndexType>
AsEigenDSizes()480 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
481 CheckDimsEqual(NDIMS);
482 return AsEigenDSizesWithPadding<NDIMS, IndexType>();
483 }
484
485 template <int NDIMS, typename IndexType>
AsEigenDSizesWithPadding()486 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
487 CheckDimsAtLeast(NDIMS);
488 static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
489 Eigen::DSizes<IndexType, NDIMS> dsizes;
490 for (int d = 0; d < dims(); d++) {
491 dsizes[d] = static_cast<IndexType>(dim_size(d));
492 }
493 for (int d = dims(); d < NDIMS; d++) {
494 dsizes[d] = 1;
495 }
496 return dsizes;
497 }
498
499 // ----------------------------------------------------------------------------
500 // Inlining of some performance critical routines
501 // ----------------------------------------------------------------------------
502
// Copy constructor.  Inline representations are copied with one memcpy of
// the packed buffer; an out-of-line source is handed to SlowCopyFrom().
inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() == REP_OUT_OF_LINE) {
    // Our buffer is still uninitialized; mark it as inline first so that
    // SlowCopyFrom does not try to deallocate.
    set_tag(REP16);
    SlowCopyFrom(b);
    return;
  }
  // Copying the whole buffer also copies the rank byte and the tag, i.e. it
  // implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  memcpy(buf(), b.buf(), sizeof(u_.buf));
}
515
TensorShapeRep(TensorShapeRep && b)516 inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
517 num_elements_ = b.num_elements_;
518 memcpy(buf(), b.buf(), sizeof(u_.buf));
519 // memcpy above Implicitly does:
520 // set_ndims_byte(b.ndims_byte());
521 // set_tag(b.tag());
522 b.set_tag(REP16); // other shape no longer owns out-of-line data, if any.
523 }
524
~TensorShapeRep()525 inline TensorShapeRep::~TensorShapeRep() {
526 if (tag() == REP_OUT_OF_LINE) {
527 DestructorOutOfLine();
528 }
529 }
530
531 inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
532 num_elements_ = b.num_elements_;
533 if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
534 memcpy(buf(), b.buf(), sizeof(u_.buf));
535 // memcpy above implicitly also does:
536 // set_tag(b.tag());
537 // set_ndims_byte(b.ndims_byte());
538 } else {
539 SlowCopyFrom(b);
540 }
541 }
542
543 inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
544 if (tag() == REP_OUT_OF_LINE) {
545 DestructorOutOfLine();
546 }
547 num_elements_ = b.num_elements_;
548 memcpy(buf(), b.buf(), sizeof(u_.buf));
549 // memcpy above Implicitly does:
550 // set_ndims_byte(b.ndims_byte());
551 // set_tag(b.tag());
552 b.set_tag(REP16); // other shape no longer owns out-of-line data, if any.
553 }
554
555 inline TensorShape::operator const PartialTensorShape&() const {
556 // Downcast to the shared representation and upcast to PartialTensorShape
557 const TensorShapeRep* rep = this;
558 return *static_cast<const PartialTensorShape*>(rep);
559 }
560
561 template <class Shape>
TensorShapeBase(DataType dt)562 inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
563 set_tag(REP16);
564 set_data_type(dt);
565
566 // Optimized implementation of InitDims() where the shape is statically known
567 // to be {0}.
568 set_ndims_byte(1);
569 uint16* dst = as16()->dims_;
570 *dst = 0;
571 set_num_elements(0);
572 }
573
// Declare explicit instantiations in .cc file.  The `extern template`
// declarations stop translation units that include this header from
// implicitly instantiating TensorShapeBase for these two shape types.
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;
577
// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
// often used in shape inference.
struct DtypeAndPartialTensorShape {
  // Element type half of the pair.
  DataType dtype;
  // Shape half of the pair; may be only partially known.
  PartialTensorShape shape;
};
584
585 } // namespace tensorflow
586
587 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
588