/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_shape.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

// TensorShape and PartialTensorShape should have no fields beyond
// TensorShapeRep. In particular, their sizes should be the same.
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
              "TensorShape must have no fields beyond TensorShapeRep");
static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
              "PartialTensorShape must have no fields beyond TensorShapeRep");

template <class Shape>
static void AppendTo(const TensorShapeBase<Shape>& s,
                     gtl::InlinedVector<int64, 8>* vals) {
  for (auto dim : s) {
    vals->push_back(dim.size);
  }
}

void TensorShape::CheckDimsEqual(int NDIMS) const {
  CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
                          << " from a tensor of " << dims() << " dimensions";
}

void TensorShape::CheckDimsAtLeast(int NDIMS) const {
  CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
                          << " dimensions from a tensor of " << dims()
                          << " dimensions";
}

template <class Shape>
bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) return false;
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) return false;
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) return false;
    }
  }
  return true;
}

template <class Shape>
Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    if (proto.dim_size() > 0) {
      return errors::InvalidArgument(
          "An unknown shape must not have any dimensions set.");
    }
    return Status::OK();
  }
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) {
    return errors::InvalidArgument("Shape ", DebugString(proto),
                                   " has too many dimensions");
  }
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) {
      if (kIsPartial) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " has dimensions with values below -1 (where -1 means unknown)");
      } else {
        return errors::InvalidArgument("Shape ", DebugString(proto),
                                       " is not fully defined");
      }
    }
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " is too large (more than 2**63 - 1 entries)");
      }
    }
  }
  return Status::OK();
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
    for (const auto& d : proto.dim()) {
      AddDim(d.size());
    }
  }
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  InitDims(dim_sizes);
}

// Returns true iff partial is true and val is < 0.
// REQUIRES: val < kMaxRep16
// REQUIRES: partial || val >= 0
static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
  if (partial) {
    if (val < 0) {
      dst[dim] = std::numeric_limits<uint16>::max();
      return true;
    }
  } else {
    CHECK_GE(val, 0);
  }
  dst[dim] = val;
  return false;
}

template <class Shape>
void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
  DCHECK_EQ(tag(), REP16);

  // Allow sizes that are under kint64max^0.25 so that the 4-way
  // multiplication below cannot overflow: kMaxSmall = 0xd744 = 55108, and
  // 55108^4 ~= 9.2227e18 <= kint64max ~= 9.2234e18.
  static const uint64 kMaxSmall = 0xd744;
  static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
                "bad overflow check");
  bool large_size = false;
  for (auto s : dim_sizes) {
    if (s > kMaxSmall) {
      large_size = true;
      break;
    }
  }

  if (!large_size) {
    // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
    uint16* dst = as16()->dims_;
    switch (dim_sizes.size()) {
      case 1: {
        set_ndims_byte(1);
        const int64 size = dim_sizes[0];
        const bool neg = Set16(kIsPartial, dst, 0, size);
        set_num_elements(neg ? -1 : size);
        return;
      }
      case 2: {
        set_ndims_byte(2);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        set_num_elements(neg ? -1 : (size0 * size1));
        return;
      }
      case 3: {
        set_ndims_byte(3);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        const int64 size2 = dim_sizes[2];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        set_num_elements(neg ? -1 : (size0 * size1 * size2));
        return;
      }
      case 4: {
        set_ndims_byte(4);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        const int64 size2 = dim_sizes[2];
        const int64 size3 = dim_sizes[3];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        neg |= Set16(kIsPartial, dst, 3, size3);
        set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
        return;
      }
    }
  }

  set_ndims_byte(0);
  set_num_elements(1);
  for (int64 s : dim_sizes) {
    AddDim(internal::SubtleMustCopy(s));
  }
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase() {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  if (kIsPartial) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
  }
}
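
// Note: the default TensorShape is a scalar shape [] with one element, while
// the default PartialTensorShape has unknown rank. Illustrative sketch (not
// exercised in this file):
//   TensorShape s;         // s.dims() == 0, s.num_elements() == 1
//   PartialTensorShape p;  // p.unknown_rank() == true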

void TensorShapeRep::DestructorOutOfLine() {
  DCHECK(tag() == REP_OUT_OF_LINE);
  delete as64()->dims_;
}

void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
    //   set_data_type(b.data_type());
  } else {
    DCHECK_EQ(b.tag(), REP_OUT_OF_LINE);
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
    }
  }
}

template <class Shape>
int64 TensorShapeBase<Shape>::dim_size(int d) const {
  if (unknown_rank()) return -1;
  DCHECK_GE(d, 0);
  DCHECK_LT(d, dims());
  if (tag() == REP16) {
    uint16 dim = as16()->dims_[d];
    if (kIsPartial && dim == kUnknownRep16) return -1;
    return dim;
  } else if (tag() == REP32) {
    uint32 dim = as32()->dims_[d];
    if (kIsPartial && dim == kUnknownRep32) return -1;
    return dim;
  } else {
    return (*as64()->dims_)[d];
  }
}

void TensorShapeRep::Clear() {
  ClearAllButDataType();
  set_data_type(DT_INVALID);
}

void TensorShapeRep::ClearAllButDataType() {
  if (tag() == REP_OUT_OF_LINE) {
    delete as64()->dims_;
  }
  set_tag(REP16);
  set_ndims_byte(0);
  // Leaves data_type alone
  set_num_elements(1);
}

template <class Shape>
void TensorShapeBase<Shape>::RecomputeNumElements() {
  if (unknown_rank()) {
    set_num_elements(-1);
    return;
  }
  int64 n = 1;
  for (auto dim : *this) {
    if (kIsPartial && dim.size < 0) {
      n = -1;
      break;
    }
    n = MultiplyWithoutOverflow(n, dim.size);
    CHECK_LE(0, n);
  }
  set_num_elements(n);
}

template <class Shape>
void TensorShapeBase<Shape>::AddDim(int64 size) {
  if (!kIsPartial) CHECK_GE(size, 0);
  if (unknown_rank()) return;
  CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
  int64 new_num_elements;
  if (kIsPartial && (num_elements() < 0 || size < 0)) {
    new_num_elements = -1;
  } else {
    new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    CHECK_LE(0, new_num_elements);
  }
  UnsafeAddDim(size, new_num_elements);
}
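
// Illustrative use of AddDim (a sketch, not part of this file's tests):
//   TensorShape s;              // []
//   s.AddDim(2);                // [2]
//   s.AddDim(3);                // [2,3], s.num_elements() == 6
//   PartialTensorShape p({2});
//   p.AddDim(-1);               // [2,?], p.num_elements() == -1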

template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
  const int nd = ndims_byte();
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation: copy the existing dimensions out,
    // append the new one, and re-encode as REP32 or REP_OUT_OF_LINE.
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16. See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}

template <class Shape>
void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
  for (auto d : shape) AddDim(d.size);
}

template <class Shape>
void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LE(d, dims());
  if (!kIsPartial) CHECK_GE(size, 0);
  CHECK_LT(dims(), MaxDimensions());
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.insert(vals.begin() + d, size);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
}

template <class Shape>
gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
  gtl::InlinedVector<int64, 4> result;
  for (auto dim : *this) {
    result.push_back(dim.size);
  }
  return result;
}

template <class Shape>
void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());
  // As in AddDim and InsertDim, only fully-defined shapes reject negative
  // sizes; partial shapes map them to "unknown" below.
  if (!kIsPartial) CHECK_GE(size, 0);
  if (tag() == REP16 && size < kMaxRep16) {
    as16()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && size < kMaxRep32) {
    as32()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    (*as64()->dims_)[d] = size;
  } else {
    // Must upgrade
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals[d] = size;
    ClearAllButDataType();
    for (auto dval : vals) {
      AddDim(dval);
    }
  }
  RecomputeNumElements();
}

template <class Shape>
void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
  if (unknown_rank()) return;
  begin = begin < 0 ? dims() + begin + 1 : begin;
  end = end < 0 ? dims() + end + 1 : end;
  CHECK_GE(begin, 0);
  CHECK_LE(begin, dims());
  CHECK_GE(end, 0);
  CHECK_LE(end, dims());
  if (begin >= end) return;
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.erase(vals.begin() + begin, vals.begin() + end);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
  RecomputeNumElements();
}
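
// Negative indices above count from the end, offset by one: for a rank-4
// shape, begin == -2 maps to dims() - 1 and end == -1 maps to dims().
// Illustrative sketch:
//   TensorShape s({2, 3, 5, 7});
//   s.RemoveDimRange(1, 3);    // erases dims [1,3): s is now [2,7]
//   TensorShape t({2, 3, 5, 7});
//   t.RemoveDimRange(-2, -1);  // erases dim 3: t is now [2,3,5]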

bool TensorShape::IsSameSize(const TensorShape& b) const {
  if (b.dims() != dims()) return false;
  for (int d = 0; d < dims(); d++) {
    if (dim_size(d) != b.dim_size(d)) return false;
  }
  return true;
}

template <class Shape>
void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
  proto->Clear();
  if (unknown_rank()) {
    proto->set_unknown_rank(true);
  } else {
    for (int i = 0; i < dims(); i++) {
      proto->add_dim()->set_size(dim_size(i));
    }
  }
}

void TensorShapeRep::DumpRep() const {
#if 0
  fprintf(stderr, "Rep: %d %d dims\n", tag(), ndims_byte());
  if (tag() == REP16) {
    fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte());
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]);
    }
  } else if (tag() == REP32) {
    fprintf(stderr, "REP32 NDIMS: %d\n", ndims_byte());
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]);
    }
  } else if (tag() == REP_OUT_OF_LINE) {
    fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_byte(),
            as64()->dims_);
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]);
    }
  }
#endif
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
  CHECK(!unknown_rank());
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims());
}

string TensorShapeRep::DebugString() const {
  const auto& shape = *static_cast<const PartialTensorShape*>(this);
  if (shape.unknown_rank()) return "<unknown>";
  string s = "[";
  for (int i = 0; i < shape.dims(); i++) {
    if (i > 0) strings::StrAppend(&s, ",");
    int64 dim = shape.dim_size(i);
    if (dim < 0) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, dim);
    }
  }
  strings::StrAppend(&s, "]");
  return s;
}
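
// Example output (illustrative): a fully-known shape prints as "[2,3]", a
// partial shape with an unknown dimension as "[2,?]", and a shape of unknown
// rank as "<unknown>".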

string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
  string s;
  if (proto.unknown_rank()) {
    strings::StrAppend(&s, "<unknown>");
    if (proto.dim_size() == 0) return s;
  }
  strings::StrAppend(&s, "[");
  bool first = true;
  for (const auto& d : proto.dim()) {
    if (!first) strings::StrAppend(&s, ",");
    if (d.size() == -1) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, d.size());
    }
    first = false;
  }
  strings::StrAppend(&s, "]");
  return s;
}

bool TensorShapeUtils::StartsWith(const TensorShape& shape,
                                  const TensorShape& prefix) {
  if (shape.dims() < prefix.dims()) return false;
  for (int i = 0; i < prefix.dims(); ++i) {
    if (shape.dim_size(i) != prefix.dim_size(i)) return false;
  }
  return true;
}

bool TensorShapeUtils::EndsWith(const TensorShape& shape,
                                const TensorShape& suffix) {
  const int suffix_size = suffix.dims();
  if (shape.dims() < suffix_size) return false;
  for (int i = 0; i < suffix_size; ++i) {
    if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
      return false;
    }
  }
  return true;
}
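
// Illustrative sketch of the prefix/suffix checks above:
//   TensorShape s({2, 3, 5});
//   TensorShapeUtils::StartsWith(s, TensorShape({2, 3}));  // true
//   TensorShapeUtils::EndsWith(s, TensorShape({3, 5}));    // true
//   TensorShapeUtils::EndsWith(s, TensorShape({2, 5}));    // false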

template <typename T, class Shape>
Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
  out->Clear();
  if (n > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Too many dimensions");
  }
  if (n < 0) {
    return errors::InvalidArgument("Negative number of dimensions ", n);
  }
  for (int64 i = 0; i < n; ++i) {
    T dim = internal::SubtleMustCopy(dims[i]);
    int64 new_num_elements;
    if (dim < 0) {
      if (!out->kIsPartial) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
      }
      if (dim < -1) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
      }
      dim = -1;
      new_num_elements = -1;
    } else if (out->num_elements() < 0) {
      new_num_elements = -1;
    } else {
      new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
      if (TF_PREDICT_FALSE(new_num_elements < 0)) {
        TensorShapeProto proto;
        for (int64 j = 0; j < n; ++j) {
          // Report the full requested shape, not the current dimension.
          proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
        }
        return errors::InvalidArgument(
            "Shape ", TensorShape::DebugString(proto),
            " would have more than 2**63 - 1 elements");
      }
    }
    out->UnsafeAddDim(dim, new_num_elements);
  }
  return Status::OK();
}

#define MAKE_SHAPE(T, Shape)                                                 \
  Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
    return MakeShapeHelper(dims, n, out);                                    \
  }                                                                          \
  Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
    return MakeShapeHelper(shape.data(), shape.size(), out);                 \
  }
MAKE_SHAPE(int32, TensorShape)
MAKE_SHAPE(int64, TensorShape)
MAKE_SHAPE(int32, PartialTensorShape)
MAKE_SHAPE(int64, PartialTensorShape)
#undef MAKE_SHAPE

string TensorShapeUtils::ShapeListString(
    const gtl::ArraySlice<TensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const TensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
  PartialTensorShape out = *this;
  out.AddDim(size);
  return out;
}

PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return PartialTensorShape();
  }
  PartialTensorShape out = *this;
  for (auto dim : shape) out.AddDim(dim.size);
  return out;
}
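
// Illustrative sketch of Concatenate (both overloads return a new shape):
//   PartialTensorShape p({2, -1});
//   p.Concatenate(3);                           // returns [2,?,3]
//   p.Concatenate(PartialTensorShape({4, 5}));  // returns [2,?,4,5]
//   p.Concatenate(PartialTensorShape());        // returns <unknown>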

Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
                                     PartialTensorShape* result) const {
  if (unknown_rank()) {
    *result = shape;
    return Status::OK();
  }
  if (shape.unknown_rank()) {
    *result = *this;
    return Status::OK();
  }
  const int dims_ = dims();
  if (dims_ != shape.dims()) {
    return errors::InvalidArgument(
        "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
        shape.dims());
  }
  CHECK(result != this);
  result->Clear();
  for (int i = 0; i < dims_; ++i) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
      return errors::InvalidArgument(
          "PartialTensorShape: Incompatible shapes during merge: ",
          DebugString(), " vs. ", shape.DebugString());
    }
    result->AddDim(dim0 >= 0 ? dim0 : dim1);
  }
  return Status::OK();
}
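
// Merging fills in dimensions that either side knows, e.g. (illustrative):
//   PartialTensorShape a({2, -1}), b({-1, 3}), merged;
//   TF_CHECK_OK(a.MergeWith(b, &merged));  // merged is [2,3]
//   PartialTensorShape c({4, 3});
//   Status s = a.MergeWith(c, &merged);    // error: 2 vs. 4 in dim 0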

bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
  if (IsFullyDefined()) {
    const TensorShapeRep* rep = this;
    *shape = *static_cast<const TensorShape*>(rep);
    return true;
  }
  return false;
}

bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return unknown_rank() == shape.unknown_rank();
  }
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    if (dim_size(i) != shape.dim_size(i)) return false;
  }
  return true;
}

bool PartialTensorShape::IsCompatibleWith(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return true;
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
  }
  return true;
}
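
// Compatibility is weaker than identity: two shapes are compatible if some
// fully-defined shape satisfies both. Illustrative:
//   PartialTensorShape({2, -1}).IsCompatibleWith(PartialTensorShape({2, 3}))
//       == true   (both match [2,3])
//   PartialTensorShape({2, -1}).IsIdenticalTo(PartialTensorShape({2, 3}))
//       == false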

string PartialTensorShapeUtils::PartialShapeListString(
    const gtl::ArraySlice<PartialTensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const PartialTensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

bool PartialTensorShapeUtils::AreCompatible(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

bool PartialTensorShapeUtils::AreIdentical(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
                                     int64* num_elements) {
  int64 n = 1;
  for (auto dim : shape) {
    n = MultiplyWithoutOverflow(n, dim);
    if (n < 0) {
      return errors::InvalidArgument("Can't compute total size of shape [",
                                     str_util::Join(shape, ","),
                                     "]; product would overflow int64");
    }
  }
  *num_elements = n;
  return Status::OK();
}
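
// Illustrative: NumElements({2, 3, 5}, &n) sets n to 30, while a product
// exceeding 2**63 - 1 (e.g. {1LL << 32, 1LL << 32}) returns InvalidArgument.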

template class TensorShapeBase<TensorShape>;
template class TensorShapeBase<PartialTensorShape>;

}  // namespace tensorflow