/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_shape.h"

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

// TensorShape and PartialTensorShape should have no fields beyond
// TensorShapeRep. In particular, their sizes should be the same.
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
              "TensorShape must have no fields beyond TensorShapeRep");
static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
              "PartialTensorShape must have no fields beyond TensorShapeRep");

template <class Shape>
static void AppendTo(const TensorShapeBase<Shape>& s,
                     gtl::InlinedVector<int64, 8>* vals) {
  for (auto dim : s) {
    vals->push_back(dim.size);
  }
}

void TensorShape::CheckDimsEqual(int NDIMS) const {
  CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
                          << " from a tensor of " << dims() << " dimensions";
}

void TensorShape::CheckDimsAtLeast(int NDIMS) const {
  CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
                          << " dimensions from a tensor of " << dims()
                          << " dimensions";
}

// TODO(slebedev): Consider merging IsValid implementations.
template <class Shape>
bool TensorShapeBase<Shape>::IsValid() {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && unknown_rank()) return dims() == 0;
  int64 num_elements = 1;
  if (dims() > MaxDimensions()) return false;
  for (auto d : dim_sizes()) {
    if (d < (kIsPartial ? -1 : 0)) return false;
    if (d == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d);
      if (num_elements < 0) return false;
    }
  }
  return true;
}

template <class Shape>
bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) return false;
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) return false;
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) return false;
    }
  }
  return true;
}

template <class Shape>
Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    if (proto.dim_size() > 0) {
      return errors::InvalidArgument(
          "An unknown shape must not have any dimensions set.");
    }
    return Status::OK();
  }
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) {
    return errors::InvalidArgument("Shape ", DebugString(proto),
                                   " has too many dimensions");
  }
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) {
      if (kIsPartial) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " has dimensions with values below -1 (where -1 means unknown)");
      } else {
        return errors::InvalidArgument("Shape ", DebugString(proto),
                                       " is not fully defined");
      }
    }
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " is too large (more than 2**63 - 1 entries)");
      }
    }
  }
  return Status::OK();
}
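
// Validation sketch (illustrative, not part of the original file): the
// boolean and Status-returning checks above differ only in how failure is
// reported. Given a hand-built proto:
//
//   TensorShapeProto proto;
//   proto.add_dim()->set_size(2);
//   proto.add_dim()->set_size(-1);          // unknown dimension
//   PartialTensorShape::IsValid(proto);     // true: -1 is allowed here
//   TensorShape::IsValid(proto);            // false: must be fully defined
//   TensorShape::IsValidShape(proto).ok();  // false, with a descriptive
//                                           // InvalidArgument status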

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
    for (const auto& d : proto.dim()) {
      AddDim(d.size());
    }
  }
}
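
// Round-trip sketch (illustrative): building a shape from a proto and back.
// Callers are expected to verify IsValid(proto) first; this constructor does
// not itself return errors.
//
//   TensorShapeProto proto;
//   proto.add_dim()->set_size(4);
//   proto.add_dim()->set_size(8);
//   TensorShape shape(proto);  // shape.dims() == 2, num_elements() == 32
//   TensorShapeProto out;
//   shape.AsProto(&out);       // out now describes the same [4,8] shape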

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  InitDims(dim_sizes);
}

// Returns true iff partial is true and val is < 0.
// REQUIRES: val < kMaxRep16
// REQUIRES: partial || val >= 0
static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
  if (partial) {
    if (val < 0) {
      dst[dim] = std::numeric_limits<uint16>::max();
      return true;
    }
  } else {
    CHECK_GE(val, 0);
  }
  dst[dim] = val;
  return false;
}

template <class Shape>
void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
  DCHECK_EQ(tag(), REP16);

  // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
  // below cannot overflow.
  static const uint64 kMaxSmall = 0xd744;
  static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
                "bad overflow check");
  bool large_size = false;
  for (auto s : dim_sizes) {
    if (s > kMaxSmall) {
      large_size = true;
      break;
    }
  }

  if (!large_size) {
    // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
    uint16* dst = as16()->dims_;
    switch (dim_sizes.size()) {
      case 1: {
        set_ndims_byte(1);
        const int64 size = dim_sizes[0];
        const bool neg = Set16(kIsPartial, dst, 0, size);
        set_num_elements(neg ? -1 : size);
        return;
      }
      case 2: {
        set_ndims_byte(2);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        set_num_elements(neg ? -1 : (size0 * size1));
        return;
      }
      case 3: {
        set_ndims_byte(3);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        const int64 size2 = dim_sizes[2];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        set_num_elements(neg ? -1 : (size0 * size1 * size2));
        return;
      }
      case 4: {
        set_ndims_byte(4);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        const int64 size2 = dim_sizes[2];
        const int64 size3 = dim_sizes[3];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        neg |= Set16(kIsPartial, dst, 3, size3);
        set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
        return;
      }
    }
  }

  set_ndims_byte(0);
  set_num_elements(1);
  for (int64 s : dim_sizes) {
    AddDim(internal::SubtleMustCopy(s));
  }
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase() {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  if (kIsPartial) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
  }
}

void TensorShapeRep::DestructorOutOfLine() {
  DCHECK(tag() == REP_OUT_OF_LINE);
  delete as64()->dims_;
}

void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
    //   set_data_type(b.data_type());
  } else {
    DCHECK_EQ(b.tag(), REP_OUT_OF_LINE);
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
    }
  }
}

template <class Shape>
int64 TensorShapeBase<Shape>::dim_size(int d) const {
  if (unknown_rank()) return -1;
  DCHECK_GE(d, 0);
  DCHECK_LT(d, dims());
  if (tag() == REP16) {
    uint16 dim = as16()->dims_[d];
    if (kIsPartial && dim == kUnknownRep16) return -1;
    return dim;
  } else if (tag() == REP32) {
    uint32 dim = as32()->dims_[d];
    if (kIsPartial && dim == kUnknownRep32) return -1;
    return dim;
  } else {
    return (*as64()->dims_)[d];
  }
}

void TensorShapeRep::Clear() {
  ClearAllButDataType();
  set_data_type(DT_INVALID);
}

void TensorShapeRep::ClearAllButDataType() {
  if (tag() == REP_OUT_OF_LINE) {
    delete as64()->dims_;
  }
  set_tag(REP16);
  set_ndims_byte(0);
  // Leaves data_type alone
  set_num_elements(1);
}

template <class Shape>
void TensorShapeBase<Shape>::RecomputeNumElements() {
  if (unknown_rank()) {
    set_num_elements(-1);
    return;
  }
  int64 n = 1;
  for (auto dim : *this) {
    if (kIsPartial && dim.size < 0) {
      n = -1;
      break;
    }
    n = MultiplyWithoutOverflow(n, dim.size);
    CHECK_LE(0, n);
  }
  set_num_elements(n);
}

template <class Shape>
void TensorShapeBase<Shape>::AddDim(int64 size) {
  if (!kIsPartial) CHECK_GE(size, 0);
  if (unknown_rank()) return;
  CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
  int64 new_num_elements;
  if (kIsPartial && (num_elements() < 0 || size < 0)) {
    new_num_elements = -1;
  } else {
    new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    CHECK_LE(0, new_num_elements);
  }
  UnsafeAddDim(size, new_num_elements);
}

template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
  const int nd = ndims_byte();
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16. See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}
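
// Representation upgrade sketch (illustrative): REP16 stores up to 6 inline
// dims that fit in 16 bits, REP32 up to 3 dims that fit in 32 bits (the max
// unsigned value of each width encodes an unknown dim), and anything larger
// spills to a heap-allocated InlinedVector.
//
//   TensorShape s({2, 3});  // fits REP16
//   s.AddDim(70000);        // 70000 >= kMaxRep16: upgrades to REP32
//   s.AddDim(5);            // 4 dims > 3: upgrades to REP_OUT_OF_LINE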

template <class Shape>
void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
  for (auto d : shape) AddDim(d.size);
}

template <class Shape>
void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LE(d, dims());
  if (!kIsPartial) CHECK_GE(size, 0);
  CHECK_LT(dims(), MaxDimensions());
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.insert(vals.begin() + d, size);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
}

template <class Shape>
gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
  gtl::InlinedVector<int64, 4> result;
  for (auto dim : *this) {
    result.push_back(dim.size);
  }
  return result;
}

template <class Shape>
void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());
  // Only fully-defined shapes must reject negative sizes; a partial shape
  // may set a dimension to -1 (unknown), which the branches below encode.
  if (!kIsPartial) CHECK_GE(size, 0);
  if (tag() == REP16 && size < kMaxRep16) {
    as16()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && size < kMaxRep32) {
    as32()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    (*as64()->dims_)[d] = size;
  } else {
    // Must upgrade
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals[d] = size;
    ClearAllButDataType();
    for (auto dval : vals) {
      AddDim(dval);
    }
  }
  RecomputeNumElements();
}

template <class Shape>
void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
  if (unknown_rank()) return;
  begin = begin < 0 ? dims() + begin + 1 : begin;
  end = end < 0 ? dims() + end + 1 : end;
  CHECK_GE(begin, 0);
  CHECK_LE(begin, dims());
  CHECK_GE(end, 0);
  CHECK_LE(end, dims());
  if (begin >= end) return;
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.erase(vals.begin() + begin, vals.begin() + end);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
  RecomputeNumElements();
}
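
// Negative-index sketch (illustrative): negative begin/end count from
// dims() + 1, so -1 maps to dims(), and [begin, end) is half-open.
//
//   TensorShape s({2, 3, 5, 7});
//   s.RemoveDimRange(1, 3);    // removes sizes 3 and 5 -> [2,7]
//   TensorShape t({2, 3, 5, 7});
//   t.RemoveDimRange(-3, -1);  // same as (2, 4): removes 5 and 7 -> [2,3]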

bool TensorShape::IsSameSize(const TensorShape& b) const {
  if (b.dims() != dims()) return false;
  for (int d = 0; d < dims(); d++) {
    if (dim_size(d) != b.dim_size(d)) return false;
  }
  return true;
}

template <class Shape>
void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
  proto->Clear();
  if (unknown_rank()) {
    proto->set_unknown_rank(true);
  } else {
    for (int i = 0; i < dims(); i++) {
      proto->add_dim()->set_size(dim_size(i));
    }
  }
}

void TensorShapeRep::DumpRep() const {
#if 0
  fprintf(stderr, "Rep: %d %d dims\n", tag(), ndims_byte());
  if (tag() == REP16) {
    fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte());
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]);
    }
  } else if (tag() == REP32) {
    fprintf(stderr, "REP32 NDIMS: %d\n", ndims_byte());
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]);
    }
  } else if (tag() == REP_OUT_OF_LINE) {
    fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_byte(),
            as64()->dims_);
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]);
    }
  }
#endif
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
  CHECK(!unknown_rank());
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims());
}
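
// Iteration sketch (illustrative): these iterators power range-for over
// shapes; each element exposes a .size field. Note that end() CHECK-fails
// for an unknown-rank PartialTensorShape.
//
//   TensorShape s({2, 3});
//   for (const auto& d : s) {
//     LOG(INFO) << d.size;  // prints 2, then 3
//   }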

string TensorShapeRep::DebugString() const {
  const auto& shape = *static_cast<const PartialTensorShape*>(this);
  if (shape.unknown_rank()) return "<unknown>";
  string s = "[";
  for (int i = 0; i < shape.dims(); i++) {
    if (i > 0) strings::StrAppend(&s, ",");
    int64 dim = shape.dim_size(i);
    if (dim < 0) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, dim);
    }
  }
  strings::StrAppend(&s, "]");
  return s;
}
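
// Output sketch (illustrative): a fully defined shape prints as "[2,3]", a
// partial shape with an unknown dimension as "[2,?,3]", and an unknown-rank
// shape as "<unknown>".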

string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
  string s;
  if (proto.unknown_rank()) {
    strings::StrAppend(&s, "<unknown>");
    if (proto.dim_size() == 0) return s;
  }
  strings::StrAppend(&s, "[");
  bool first = true;
  for (const auto& d : proto.dim()) {
    if (!first) strings::StrAppend(&s, ",");
    if (d.size() == -1) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, d.size());
    }
    first = false;
  }
  strings::StrAppend(&s, "]");
  return s;
}

bool TensorShapeUtils::StartsWith(const TensorShape& shape,
                                  const TensorShape& prefix) {
  if (shape.dims() < prefix.dims()) return false;
  for (int i = 0; i < prefix.dims(); ++i) {
    if (shape.dim_size(i) != prefix.dim_size(i)) return false;
  }
  return true;
}

bool TensorShapeUtils::EndsWith(const TensorShape& shape,
                                const TensorShape& suffix) {
  const int suffix_size = suffix.dims();
  if (shape.dims() < suffix_size) return false;
  for (int i = 0; i < suffix_size; ++i) {
    if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
      return false;
    }
  }
  return true;
}
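
// Prefix/suffix sketch (illustrative):
//
//   TensorShape s({2, 3, 5});
//   TensorShapeUtils::StartsWith(s, TensorShape({2, 3}));  // true
//   TensorShapeUtils::EndsWith(s, TensorShape({3, 5}));    // true
//   TensorShapeUtils::EndsWith(s, TensorShape({2, 5}));    // false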

template <typename T, class Shape>
Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
  out->Clear();
  if (n > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Too many dimensions");
  }
  if (n < 0) {
    return errors::InvalidArgument("Negative number of dimensions ", n);
  }
  for (int64 i = 0; i < n; ++i) {
    T dim = internal::SubtleMustCopy(dims[i]);
    int64 new_num_elements;
    if (dim < 0) {
      if (!out->kIsPartial) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
      }
      if (dim < -1) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
      }
      dim = -1;
      new_num_elements = -1;
    } else if (out->num_elements() < 0) {
      new_num_elements = -1;
    } else {
      new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
      if (TF_PREDICT_FALSE(new_num_elements < 0)) {
        TensorShapeProto proto;
        for (int64 j = 0; j < n; ++j) {
          proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
        }
        return errors::InvalidArgument(
            "Shape ", TensorShape::DebugString(proto),
            " would have more than 2**63 - 1 elements");
      }
    }
    out->UnsafeAddDim(dim, new_num_elements);
  }
  return Status::OK();
}

#define MAKE_SHAPE(T, Shape)                                                 \
  Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
    return MakeShapeHelper(dims, n, out);                                    \
  }                                                                          \
  Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
    return MakeShapeHelper(shape.data(), shape.size(), out);                 \
  }
MAKE_SHAPE(int32, TensorShape)
MAKE_SHAPE(int64, TensorShape)
MAKE_SHAPE(int32, PartialTensorShape)
MAKE_SHAPE(int64, PartialTensorShape)
#undef MAKE_SHAPE
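
// Usage sketch (illustrative): MakeShape is the error-returning alternative
// to the CHECK-failing shape constructors, e.g. in a Status-returning
// context:
//
//   TensorShape shape;
//   int32 dims[] = {128, 28, 28, 3};
//   TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(dims, 4, &shape));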

string TensorShapeUtils::ShapeListString(
    const gtl::ArraySlice<TensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const TensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
  PartialTensorShape out = *this;
  out.AddDim(size);
  return out;
}

PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return PartialTensorShape();
  }
  PartialTensorShape out = *this;
  for (auto dim : shape) out.AddDim(dim.size);
  return out;
}
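
// Concatenation sketch (illustrative): an unknown-rank operand makes the
// result unknown-rank, since the combined rank cannot be known.
//
//   PartialTensorShape a({2, -1});
//   PartialTensorShape b({3});
//   a.Concatenate(b);                     // [2,?,3]
//   a.Concatenate(PartialTensorShape());  // <unknown>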

Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
                                     PartialTensorShape* result) const {
  if (unknown_rank()) {
    *result = shape;
    return Status::OK();
  }
  if (shape.unknown_rank()) {
    *result = *this;
    return Status::OK();
  }
  const int dims_ = dims();
  if (dims_ != shape.dims()) {
    return errors::InvalidArgument(
        "PartialTensorShape: Incompatible ranks during merge: ", dims_,
        " vs. ", shape.dims());
  }
  CHECK(result != this);
  result->Clear();
  for (int i = 0; i < dims_; ++i) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
      return errors::InvalidArgument(
          "PartialTensorShape: Incompatible shapes during merge: ",
          DebugString(), " vs. ", shape.DebugString());
    }
    result->AddDim(dim0 >= 0 ? dim0 : dim1);
  }
  return Status::OK();
}
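
// Merge sketch (illustrative): merging fills in unknown dimensions from the
// other shape and fails when two known dimensions disagree.
//
//   PartialTensorShape a({2, -1, 3});   // [2,?,3]
//   PartialTensorShape b({-1, 5, 3});   // [?,5,3]
//   PartialTensorShape merged;
//   TF_CHECK_OK(a.MergeWith(b, &merged));  // merged is [2,5,3]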

bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
  if (IsFullyDefined()) {
    const TensorShapeRep* rep = this;
    *shape = *static_cast<const TensorShape*>(rep);
    return true;
  }
  return false;
}

bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return unknown_rank() == shape.unknown_rank();
  }
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    if (dim_size(i) != shape.dim_size(i)) return false;
  }
  return true;
}

bool PartialTensorShape::IsCompatibleWith(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return true;
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
  }
  return true;
}
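
// Compatibility vs. identity sketch (illustrative): compatibility treats
// unknown dimensions as wildcards; identity requires an exact match.
//
//   PartialTensorShape a({2, -1});
//   PartialTensorShape b({2, 3});
//   a.IsCompatibleWith(b);  // true: ? matches 3
//   a.IsIdenticalTo(b);     // false: ? != 3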

string PartialTensorShapeUtils::PartialShapeListString(
    const gtl::ArraySlice<PartialTensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const PartialTensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

bool PartialTensorShapeUtils::AreCompatible(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

bool PartialTensorShapeUtils::AreIdentical(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
                                     int64* num_elements) {
  int64 n = 1;
  for (auto dim : shape) {
    n = MultiplyWithoutOverflow(n, dim);
    if (n < 0) {
      return errors::InvalidArgument("Can't compute total size of shape [",
                                     absl::StrJoin(shape, ","),
                                     "]; product would overflow int64");
    }
  }
  *num_elements = n;
  return Status::OK();
}

template class TensorShapeBase<TensorShape>;
template class TensorShapeBase<PartialTensorShape>;

}  // namespace tensorflow