/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/core/framework/tensor_shape.h"
17
18 #include "tensorflow/core/framework/bounds_check.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/util/overflow.h"
27
28 namespace tensorflow {
29
30 // TensorShape and PartialTensorShape should have no fields beyond
31 // TensorShapeRep. In particular, their sizes should be the same.
32 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
33 "TensorShape must have no fields beyond TensorShapeRep");
34 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
35 "PartialTensorShape must have no fields beyond TensorShapeRep");
36
37 template <class Shape>
AppendTo(const TensorShapeBase<Shape> & s,gtl::InlinedVector<int64,8> * vals)38 static void AppendTo(const TensorShapeBase<Shape>& s,
39 gtl::InlinedVector<int64, 8>* vals) {
40 for (auto dim : s) {
41 vals->push_back(dim.size);
42 }
43 }
44
CheckDimsEqual(int NDIMS) const45 void TensorShape::CheckDimsEqual(int NDIMS) const {
46 CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
47 << " from a tensor of " << dims() << " dimensions";
48 }
49
CheckDimsAtLeast(int NDIMS) const50 void TensorShape::CheckDimsAtLeast(int NDIMS) const {
51 CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
52 << " dimensions from a tensor of " << dims()
53 << " dimensions";
54 }
55
56 // TODO(slebedev): Consider merging IsValid implementations.
57 template <class Shape>
IsValid()58 bool TensorShapeBase<Shape>::IsValid() {
59 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
60 // unknown_shape() set, and it seems hard to remove this without backwards
61 // compatibility issues.
62 if (kIsPartial && unknown_rank()) return dims() == 0;
63 int64 num_elements = 1;
64 if (dims() > MaxDimensions()) return false;
65 for (auto d : dim_sizes()) {
66 if (d < (kIsPartial ? -1 : 0)) return false;
67 if (d == -1) {
68 num_elements = -1;
69 } else if (!kIsPartial || num_elements >= 0) {
70 num_elements = MultiplyWithoutOverflow(num_elements, d);
71 if (num_elements < 0) return false;
72 }
73 }
74 return true;
75 }
76
77 template <class Shape>
IsValid(const TensorShapeProto & proto)78 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
79 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
80 // unknown_shape() set, and it seems hard to remove this without backwards
81 // compatibility issues.
82 if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
83 int64 num_elements = 1;
84 if (proto.dim().size() > MaxDimensions()) return false;
85 for (const auto& d : proto.dim()) {
86 if (d.size() < (kIsPartial ? -1 : 0)) return false;
87 if (d.size() == -1) {
88 num_elements = -1;
89 } else if (!kIsPartial || num_elements >= 0) {
90 num_elements = MultiplyWithoutOverflow(num_elements, d.size());
91 if (num_elements < 0) return false;
92 }
93 }
94 return true;
95 }
96
97 template <class Shape>
IsValidShape(const TensorShapeProto & proto)98 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
99 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
100 // unknown_shape() set, and it seems hard to remove this without backwards
101 // compatibility issues.
102 if (kIsPartial && proto.unknown_rank()) {
103 if (proto.dim_size() > 0) {
104 return errors::InvalidArgument(
105 "An unknown shape must not have any dimensions set.");
106 }
107 return Status::OK();
108 }
109 int64 num_elements = 1;
110 if (proto.dim().size() > MaxDimensions()) {
111 return errors::InvalidArgument("Shape ", DebugString(proto),
112 " has too many dimensions");
113 }
114 for (const auto& d : proto.dim()) {
115 if (d.size() < (kIsPartial ? -1 : 0)) {
116 if (kIsPartial) {
117 return errors::InvalidArgument(
118 "Shape ", DebugString(proto),
119 " has dimensions with values below -1 (where -1 means unknown)");
120 } else {
121 return errors::InvalidArgument("Shape ", DebugString(proto),
122 " is not fully defined");
123 }
124 }
125 if (d.size() == -1) {
126 num_elements = -1;
127 } else if (!kIsPartial || num_elements >= 0) {
128 num_elements = MultiplyWithoutOverflow(num_elements, d.size());
129 if (num_elements < 0) {
130 return errors::InvalidArgument(
131 "Shape ", DebugString(proto),
132 " is too large (more than 2**63 - 1 entries)");
133 }
134 }
135 }
136 return Status::OK();
137 }
138
139 template <class Shape>
TensorShapeBase(const TensorShapeProto & proto)140 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
141 set_tag(REP16);
142 set_data_type(DT_INVALID);
143 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
144 // unknown_shape() set, and it seems hard to remove this without backwards
145 // compatibility issues.
146 if (kIsPartial && proto.unknown_rank()) {
147 set_ndims_byte(kUnknownRank);
148 set_num_elements(-1);
149 } else {
150 set_ndims_byte(0);
151 set_num_elements(1);
152 for (const auto& d : proto.dim()) {
153 AddDim(d.size());
154 }
155 }
156 }
157
158 template <class Shape>
BuildTensorShapeBase(const TensorShapeProto & proto,TensorShapeBase * out)159 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
160 const TensorShapeProto& proto, TensorShapeBase* out) {
161 out->set_tag(REP16);
162 out->set_data_type(DT_INVALID);
163 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
164 // unknown_shape() set, and it seems hard to remove this without backwards
165 // compatibility issues.
166 if (kIsPartial && proto.unknown_rank()) {
167 out->set_ndims_byte(kUnknownRank);
168 out->set_num_elements(-1);
169 } else {
170 out->set_ndims_byte(0);
171 out->set_num_elements(1);
172 Status s = Status::OK();
173 for (const auto& d : proto.dim()) {
174 s = out->AddDimWithStatus(d.size());
175 if (!s.ok()) {
176 return s;
177 }
178 }
179 }
180 return Status::OK();
181 }
182
183 template <class Shape>
TensorShapeBase(gtl::ArraySlice<int64> dim_sizes)184 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
185 set_tag(REP16);
186 set_data_type(DT_INVALID);
187 TF_CHECK_OK(InitDims(dim_sizes));
188 }
189
190 template <class Shape>
BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,TensorShapeBase * out)191 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
192 gtl::ArraySlice<int64> dim_sizes, TensorShapeBase* out) {
193 out->set_tag(REP16);
194 out->set_data_type(DT_INVALID);
195 return out->InitDims(dim_sizes);
196 }
197
198 // Returns true iff partial is true and val is < 0.
199 // REQUIRES: val < kMaxRep16
200 // REQUIRES: partial || val >= 0
Set16(bool partial,uint16 * dst,int dim,int64 val)201 static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
202 if (partial) {
203 if (val < 0) {
204 dst[dim] = std::numeric_limits<uint16>::max();
205 return true;
206 }
207 }
208 dst[dim] = val;
209 return false;
210 }
211
212 template <class Shape>
InitDims(gtl::ArraySlice<int64> dim_sizes)213 Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
214 DCHECK_EQ(tag(), REP16);
215
216 // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
217 // below cannot overflow.
218 static const int64 kMaxSmall = 0xd744;
219 static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
220 "bad overflow check");
221 bool large_size = false;
222 for (auto s : dim_sizes) {
223 if (s > kMaxSmall) {
224 large_size = true;
225 break;
226 }
227 }
228
229 if (!kIsPartial && !large_size) {
230 for (auto s : dim_sizes) {
231 if (TF_PREDICT_FALSE(s < 0)) {
232 return errors::Internal(
233 "Expected shape dimensions to be non-negative, got ", s);
234 }
235 }
236 }
237
238 if (!large_size) {
239 // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
240 uint16* dst = as16()->dims_;
241 switch (dim_sizes.size()) {
242 case 1: {
243 set_ndims_byte(1);
244 const int64 size = dim_sizes[0];
245 const bool neg = Set16(kIsPartial, dst, 0, size);
246 set_num_elements(neg ? -1 : size);
247 return Status::OK();
248 }
249 case 2: {
250 set_ndims_byte(2);
251 const int64 size0 = dim_sizes[0];
252 const int64 size1 = dim_sizes[1];
253 bool neg = Set16(kIsPartial, dst, 0, size0);
254 neg |= Set16(kIsPartial, dst, 1, size1);
255 set_num_elements(neg ? -1 : (size0 * size1));
256 return Status::OK();
257 }
258 case 3: {
259 set_ndims_byte(3);
260 const int64 size0 = dim_sizes[0];
261 const int64 size1 = dim_sizes[1];
262 const int64 size2 = dim_sizes[2];
263 bool neg = Set16(kIsPartial, dst, 0, size0);
264 neg |= Set16(kIsPartial, dst, 1, size1);
265 neg |= Set16(kIsPartial, dst, 2, size2);
266 set_num_elements(neg ? -1 : (size0 * size1 * size2));
267 return Status::OK();
268 }
269 case 4: {
270 set_ndims_byte(4);
271 const int64 size0 = dim_sizes[0];
272 const int64 size1 = dim_sizes[1];
273 const int64 size2 = dim_sizes[2];
274 const int64 size3 = dim_sizes[3];
275 bool neg = Set16(kIsPartial, dst, 0, size0);
276 neg |= Set16(kIsPartial, dst, 1, size1);
277 neg |= Set16(kIsPartial, dst, 2, size2);
278 neg |= Set16(kIsPartial, dst, 3, size3);
279 set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
280 return Status::OK();
281 }
282 }
283 }
284
285 set_ndims_byte(0);
286 set_num_elements(1);
287 Status status = Status::OK();
288 for (int64 s : dim_sizes) {
289 status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
290 if (!status.ok()) {
291 return status;
292 }
293 }
294
295 return status;
296 }
297
298 template <class Shape>
TensorShapeBase()299 TensorShapeBase<Shape>::TensorShapeBase() {
300 set_tag(REP16);
301 set_data_type(DT_INVALID);
302 if (kIsPartial) {
303 set_ndims_byte(kUnknownRank);
304 set_num_elements(-1);
305 } else {
306 set_ndims_byte(0);
307 set_num_elements(1);
308 }
309 }
310
DestructorOutOfLine()311 void TensorShapeRep::DestructorOutOfLine() {
312 DCHECK(tag() == REP_OUT_OF_LINE);
313 delete as64()->dims_;
314 }
315
SlowCopyFrom(const TensorShapeRep & b)316 void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
317 if (b.tag() != REP_OUT_OF_LINE) {
318 if (tag() == REP_OUT_OF_LINE) {
319 delete as64()->dims_;
320 }
321 memcpy(buf(), b.buf(), sizeof(u_.buf));
322 // memcpy above implicitly also does:
323 // set_tag(b.tag());
324 // set_ndims_byte(b.ndims_byte());
325 // set_data_type(b.data_type());
326 } else {
327 set_ndims_byte(b.ndims_byte());
328 set_data_type(b.data_type());
329 if (tag() == REP_OUT_OF_LINE) {
330 // vector already allocated
331 *(as64()->dims_) = *(b.as64()->dims_);
332 } else {
333 set_tag(REP_OUT_OF_LINE);
334 as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
335 }
336 }
337 }
338
339 template <class Shape>
dim_size(int d) const340 int64 TensorShapeBase<Shape>::dim_size(int d) const {
341 if (unknown_rank()) return -1;
342 DCHECK_GE(d, 0);
343 DCHECK_LT(d, dims());
344 if (tag() == REP16) {
345 uint16 dim = as16()->dims_[d];
346 if (kIsPartial && dim == kUnknownRep16) return -1;
347 return dim;
348 } else if (tag() == REP32) {
349 uint32 dim = as32()->dims_[d];
350 if (kIsPartial && dim == kUnknownRep32) return -1;
351 return dim;
352 } else {
353 return (*as64()->dims_)[d];
354 }
355 }
356
Clear()357 void TensorShapeRep::Clear() {
358 ClearAllButDataType();
359 set_data_type(DT_INVALID);
360 }
361
ClearAllButDataType()362 void TensorShapeRep::ClearAllButDataType() {
363 if (tag() == REP_OUT_OF_LINE) {
364 delete as64()->dims_;
365 }
366 set_tag(REP16);
367 set_ndims_byte(0);
368 // Leaves data_type alone
369 set_num_elements(1);
370 }
371
372 template <class Shape>
RecomputeNumElements()373 Status TensorShapeBase<Shape>::RecomputeNumElements() {
374 if (unknown_rank()) {
375 set_num_elements(-1);
376 return Status::OK();
377 }
378 int64 n = 1;
379 for (auto dim : *this) {
380 if (kIsPartial && dim.size < 0) {
381 n = -1;
382 break;
383 }
384 n = MultiplyWithoutOverflow(n, dim.size);
385 if (TF_PREDICT_FALSE(n < 0)) {
386 return errors::InvalidArgument(
387 "Shape ", this->DebugString(),
388 " results in overflow when computing number of elements");
389 }
390 }
391 set_num_elements(n);
392 return Status::OK();
393 }
394
395 template <class Shape>
AddDim(int64 size)396 void TensorShapeBase<Shape>::AddDim(int64 size) {
397 if (!kIsPartial) CHECK_GE(size, 0);
398 if (unknown_rank()) return;
399 CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
400 int64 new_num_elements;
401 if (kIsPartial && (num_elements() < 0 || size < 0)) {
402 new_num_elements = -1;
403 } else {
404 new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
405 CHECK_LE(0, new_num_elements);
406 }
407 UnsafeAddDim(size, new_num_elements);
408 }
409
410 template <class Shape>
AddDimWithStatus(int64 size)411 Status TensorShapeBase<Shape>::AddDimWithStatus(int64 size) {
412 if (!kIsPartial) {
413 if (TF_PREDICT_FALSE(size < 0)) {
414 return errors::Internal("Expected a non-negative size, got ", size);
415 }
416 }
417
418 if (unknown_rank()) {
419 return Status::OK();
420 }
421
422 if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
423 return errors::Internal("Too many dimensions in tensor");
424 }
425
426 int64 new_num_elements;
427 if (kIsPartial && (num_elements() < 0 || size < 0)) {
428 new_num_elements = -1;
429 } else {
430 new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
431 if (TF_PREDICT_FALSE(new_num_elements < 0)) {
432 return errors::Internal("Encountered overflow when multiplying ",
433 num_elements(), " with ", size,
434 ", result: ", new_num_elements);
435 }
436 }
437
438 UnsafeAddDim(size, new_num_elements);
439 return Status::OK();
440 }
441
442 template <class Shape>
UnsafeAddDim(int64 size,int64 new_num_elements)443 void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
444 const int nd = ndims_byte();
445 if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
446 as16()->dims_[nd] =
447 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
448 } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
449 as32()->dims_[nd] =
450 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
451 } else if (tag() == REP_OUT_OF_LINE) {
452 as64()->dims_->push_back(size);
453 } else {
454 // Need to change representation
455 gtl::InlinedVector<int64, 8> vals;
456 AppendTo(*this, &vals);
457 vals.push_back(size);
458 // We know we can't be REP16. See if we have a small enough
459 // number of dimensions and each dimension's size is small enough
460 // to allow REP32.
461 bool can_be_rep32 = (vals.size() <= 3);
462 if (can_be_rep32) {
463 for (size_t i = 0; i < vals.size(); i++) {
464 if (vals[i] >= kMaxRep32) {
465 can_be_rep32 = false;
466 break;
467 }
468 }
469 }
470 if (can_be_rep32) {
471 set_tag(REP32);
472 for (size_t d = 0; d < vals.size(); d++) {
473 as32()->dims_[d] = kIsPartial && vals[d] < 0
474 ? kUnknownRep32
475 : static_cast<uint32>(vals[d]);
476 }
477 } else {
478 set_tag(REP_OUT_OF_LINE);
479 as64()->dims_ =
480 new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
481 }
482 }
483 set_ndims_byte(nd + 1);
484 set_num_elements(new_num_elements);
485 }
486
487 template <class Shape>
AppendShape(const TensorShapeBase & shape)488 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
489 for (auto d : shape) AddDim(d.size);
490 }
491
492 template <class Shape>
AppendShapeWithStatus(const TensorShapeBase & shape)493 Status TensorShapeBase<Shape>::AppendShapeWithStatus(
494 const TensorShapeBase& shape) {
495 Status s = Status::OK();
496 for (auto d : shape) {
497 s.Update(AddDimWithStatus(d.size));
498 if (!s.ok()) {
499 return s;
500 }
501 }
502 return s;
503 }
504
505 template <class Shape>
InsertDim(int d,int64 size)506 void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
507 CHECK_GE(d, 0);
508 CHECK_LE(d, dims());
509 if (!kIsPartial) CHECK_GE(size, 0);
510 CHECK_LT(dims(), MaxDimensions());
511 gtl::InlinedVector<int64, 8> vals;
512 AppendTo(*this, &vals);
513 vals.insert(vals.begin() + d, size);
514 ClearAllButDataType();
515 for (auto dval : vals) {
516 AddDim(dval);
517 }
518 }
519
520 template <class Shape>
InsertDimWithStatus(int d,int64 size)521 Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64 size) {
522 if (!kIsPartial) {
523 if (TF_PREDICT_FALSE(size < 0)) {
524 return errors::Internal("Expected a non-negative size, got ", size);
525 }
526 }
527
528 if (TF_PREDICT_FALSE(d < 0)) {
529 return errors::Internal("The insertion index must be non-negative, got ",
530 d);
531 }
532 if (TF_PREDICT_FALSE(d > dims())) {
533 return errors::Internal("The insertion index must be at most ", dims(),
534 " got ", d);
535 }
536 if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
537 return errors::Internal("Shape has ", dims(),
538 " dimensions which is the maximum allowed");
539 }
540
541 gtl::InlinedVector<int64, 8> vals;
542 AppendTo(*this, &vals);
543 vals.insert(vals.begin() + d, size);
544 ClearAllButDataType();
545
546 Status s = Status::OK();
547 for (auto dval : vals) {
548 s.Update(AddDimWithStatus(dval));
549 if (!s.ok()) {
550 return s;
551 }
552 }
553 return s;
554 }
555
556 template <class Shape>
dim_sizes() const557 gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
558 gtl::InlinedVector<int64, 4> result;
559 for (auto dim : *this) {
560 result.push_back(dim.size);
561 }
562 return result;
563 }
564
565 template <class Shape>
set_dim(int d,int64 size)566 void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
567 CHECK_GE(d, 0);
568 CHECK_LT(d, dims());
569 CHECK_GE(size, 0);
570 if (tag() == REP16 && size < kMaxRep16) {
571 as16()->dims_[d] =
572 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
573 } else if (tag() == REP32 && size < kMaxRep32) {
574 as32()->dims_[d] =
575 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
576 } else if (tag() == REP_OUT_OF_LINE) {
577 (*as64()->dims_)[d] = size;
578 } else {
579 // Must upgrade
580 gtl::InlinedVector<int64, 8> vals;
581 AppendTo(*this, &vals);
582 vals[d] = size;
583 ClearAllButDataType();
584 for (auto dval : vals) {
585 AddDim(dval);
586 }
587 }
588 TF_CHECK_OK(RecomputeNumElements());
589 }
590
591 template <class Shape>
SetDimWithStatus(int d,int64 size)592 Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64 size) {
593 if (TF_PREDICT_FALSE(d < 0)) {
594 return errors::Internal("Index must be non-negative, got ", d);
595 }
596 if (TF_PREDICT_FALSE(d >= dims())) {
597 return errors::Internal("Index must be less than ", dims(), ", got ", d);
598 }
599 if (TF_PREDICT_FALSE(size < 0)) {
600 return errors::Internal("Expected a non-negative size, got ", size);
601 }
602
603 if (tag() == REP16 && size < kMaxRep16) {
604 as16()->dims_[d] =
605 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
606 } else if (tag() == REP32 && size < kMaxRep32) {
607 as32()->dims_[d] =
608 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
609 } else if (tag() == REP_OUT_OF_LINE) {
610 (*as64()->dims_)[d] = size;
611 } else {
612 // Must upgrade
613 gtl::InlinedVector<int64, 8> vals;
614 AppendTo(*this, &vals);
615 vals[d] = size;
616 ClearAllButDataType();
617
618 Status s = Status::OK();
619 for (auto dval : vals) {
620 s.Update(AddDimWithStatus(dval));
621 if (!s.ok()) {
622 return s;
623 }
624 }
625 }
626
627 return RecomputeNumElements();
628 }
629
630 template <class Shape>
RemoveDimRange(int begin,int end)631 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
632 if (unknown_rank()) return;
633 begin = begin < 0 ? dims() + begin + 1 : begin;
634 end = end < 0 ? dims() + end + 1 : end;
635 CHECK_GE(begin, 0);
636 CHECK_LE(begin, dims());
637 CHECK_GE(end, 0);
638 CHECK_LE(end, dims());
639 if (begin >= end) return;
640 gtl::InlinedVector<int64, 8> vals;
641 AppendTo(*this, &vals);
642 vals.erase(vals.begin() + begin, vals.begin() + end);
643 ClearAllButDataType();
644 for (auto dval : vals) {
645 AddDim(dval);
646 }
647 TF_CHECK_OK(RecomputeNumElements());
648 }
649
650 template <class Shape>
RemoveDimRangeWithStatus(int begin,int end)651 Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
652 if (unknown_rank()) {
653 return Status::OK();
654 }
655
656 begin = begin < 0 ? dims() + begin + 1 : begin;
657 end = end < 0 ? dims() + end + 1 : end;
658
659 if (TF_PREDICT_FALSE(begin < 0)) {
660 return errors::Internal("Start index must be non-negative, got ", begin);
661 }
662 if (TF_PREDICT_FALSE(begin > dims())) {
663 return errors::Internal("Start index must be less than ", dims(), ", got ",
664 begin);
665 }
666 if (TF_PREDICT_FALSE(end < 0)) {
667 return errors::Internal("End index must be non-negative, got ", end);
668 }
669 if (TF_PREDICT_FALSE(end > dims())) {
670 return errors::Internal("End index must be less than ", dims(), ", got ",
671 end);
672 }
673
674 if (begin >= end) {
675 return Status::OK();
676 }
677
678 gtl::InlinedVector<int64, 8> vals;
679 AppendTo(*this, &vals);
680 vals.erase(vals.begin() + begin, vals.begin() + end);
681 ClearAllButDataType();
682
683 Status s = Status::OK();
684 for (auto dval : vals) {
685 s.Update(AddDimWithStatus(dval));
686 if (!s.ok()) {
687 return s;
688 }
689 }
690
691 return RecomputeNumElements();
692 }
693
IsSameSize(const TensorShape & b) const694 bool TensorShape::IsSameSize(const TensorShape& b) const {
695 if (b.dims() != dims()) return false;
696 for (int d = 0; d < dims(); d++) {
697 if (dim_size(d) != b.dim_size(d)) return false;
698 }
699 return true;
700 }
701
702 template <class Shape>
AsProto(TensorShapeProto * proto) const703 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
704 proto->Clear();
705 if (unknown_rank()) {
706 proto->set_unknown_rank(true);
707 } else {
708 for (int i = 0; i < dims(); i++) {
709 proto->add_dim()->set_size(dim_size(i));
710 }
711 }
712 }
713
714 template <class Shape>
begin() const715 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
716 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
717 }
718
719 template <class Shape>
end() const720 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
721 const int max_dim = unknown_rank() ? -1 : dims();
722 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), max_dim);
723 }
724
DebugString() const725 string TensorShapeRep::DebugString() const {
726 const auto& shape = *static_cast<const PartialTensorShape*>(this);
727 if (shape.unknown_rank()) return "<unknown>";
728 string s = "[";
729 for (int i = 0; i < shape.dims(); i++) {
730 if (i > 0) strings::StrAppend(&s, ",");
731 int64 dim = shape.dim_size(i);
732 if (dim < 0) {
733 strings::StrAppend(&s, "?");
734 } else {
735 strings::StrAppend(&s, dim);
736 }
737 }
738 strings::StrAppend(&s, "]");
739 return s;
740 }
741
DebugString(const TensorShapeProto & proto)742 string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
743 string s;
744 if (proto.unknown_rank()) {
745 strings::StrAppend(&s, "<unknown>");
746 if (proto.dim_size() == 0) return s;
747 }
748 strings::StrAppend(&s, "[");
749 bool first = true;
750 for (const auto& d : proto.dim()) {
751 if (!first) strings::StrAppend(&s, ",");
752 if (d.size() == -1) {
753 strings::StrAppend(&s, "?");
754 } else {
755 strings::StrAppend(&s, d.size());
756 }
757 first = false;
758 }
759 strings::StrAppend(&s, "]");
760 return s;
761 }
762
StartsWith(const TensorShape & shape,const TensorShape & prefix)763 bool TensorShapeUtils::StartsWith(const TensorShape& shape,
764 const TensorShape& prefix) {
765 if (shape.dims() < prefix.dims()) return false;
766 for (int i = 0; i < prefix.dims(); ++i) {
767 if (shape.dim_size(i) != prefix.dim_size(i)) return false;
768 }
769 return true;
770 }
771
EndsWith(const TensorShape & shape,const TensorShape & suffix)772 bool TensorShapeUtils::EndsWith(const TensorShape& shape,
773 const TensorShape& suffix) {
774 const int suffix_size = suffix.dims();
775 if (shape.dims() < suffix_size) return false;
776 for (int i = 0; i < suffix_size; ++i) {
777 if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
778 return false;
779 }
780 }
781 return true;
782 }
783
784 template <typename T, class Shape>
MakeShapeHelper(const T * dims,int64 n,Shape * out)785 Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
786 out->Clear();
787 if (n > TensorShape::MaxDimensions()) {
788 return errors::InvalidArgument("Too many dimensions");
789 }
790 if (n < 0) {
791 return errors::InvalidArgument("Negative number of dimensions ", n);
792 }
793 for (int64 i = 0; i < n; ++i) {
794 T dim = internal::SubtleMustCopy(dims[i]);
795 int64 new_num_elements;
796 if (dim < 0) {
797 if (!out->kIsPartial) {
798 return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
799 }
800 if (dim < -1) {
801 return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
802 }
803 dim = -1;
804 new_num_elements = -1;
805 } else if (out->num_elements() < 0) {
806 new_num_elements = -1;
807 } else {
808 new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
809 if (TF_PREDICT_FALSE(new_num_elements < 0)) {
810 TensorShapeProto proto;
811 for (int64 j = 0; j < n; ++j) {
812 proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
813 }
814 return errors::InvalidArgument(
815 "Shape ", TensorShape::DebugString(proto),
816 " would have more than 2**63 - 1 elements");
817 }
818 }
819 out->UnsafeAddDim(dim, new_num_elements);
820 }
821 return Status::OK();
822 }
823
824 #define MAKE_SHAPE(T, Shape) \
825 Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) { \
826 return MakeShapeHelper(dims, n, out); \
827 } \
828 Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
829 return MakeShapeHelper(shape.data(), shape.size(), out); \
830 }
MAKE_SHAPE(int32,TensorShape)831 MAKE_SHAPE(int32, TensorShape)
832 MAKE_SHAPE(int64, TensorShape)
833 MAKE_SHAPE(int32, PartialTensorShape)
834 MAKE_SHAPE(int64, PartialTensorShape)
835 #undef MAKE_SHAPE
836
837 string TensorShapeUtils::ShapeListString(
838 const gtl::ArraySlice<TensorShape>& shapes) {
839 string result = "[";
840 bool first = true;
841 for (const TensorShape& shape : shapes) {
842 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
843 first = false;
844 }
845 strings::StrAppend(&result, "]");
846 return result;
847 }
848
Concatenate(int64 size) const849 PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
850 PartialTensorShape out = *this;
851 out.AddDim(size);
852 return out;
853 }
854
ConcatenateWithStatus(int64 size,PartialTensorShape * out) const855 Status PartialTensorShape::ConcatenateWithStatus(
856 int64 size, PartialTensorShape* out) const {
857 out = const_cast<PartialTensorShape*>(this);
858 return out->AddDimWithStatus(size);
859 }
860
// Returns this shape followed by the dimensions of `shape`. If either shape has
// unknown rank, the result has unknown rank.
PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return PartialTensorShape();
  PartialTensorShape combined(*this);
  for (const auto& entry : shape) combined.AddDim(entry.size);
  return combined;
}
870
ConcatenateWithStatus(const PartialTensorShape & shape,PartialTensorShape * out) const871 Status PartialTensorShape::ConcatenateWithStatus(
872 const PartialTensorShape& shape, PartialTensorShape* out) const {
873 if (unknown_rank() || shape.unknown_rank()) {
874 *out = PartialTensorShape();
875 return Status::OK();
876 }
877 out = const_cast<PartialTensorShape*>(this);
878 for (auto dim : shape) {
879 Status s = out->AddDimWithStatus(dim.size);
880 if (!s.ok()) return s;
881 }
882
883 return Status::OK();
884 }
885
MergeWith(const PartialTensorShape & shape,PartialTensorShape * result) const886 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
887 PartialTensorShape* result) const {
888 if (unknown_rank()) {
889 *result = shape;
890 return Status::OK();
891 }
892 if (shape.unknown_rank()) {
893 *result = *this;
894 return Status::OK();
895 }
896 const int dims_ = dims();
897 if (dims_ != shape.dims()) {
898 return errors::InvalidArgument(
899 "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
900 shape.dims());
901 }
902
903 if (result == this) {
904 return errors::Internal(
905 "PartialTensorShape::MergeWith: cannot merge shape with itself");
906 }
907
908 result->Clear();
909 Status s = Status::OK();
910 for (int i = 0; i < dims_; ++i) {
911 const int64 dim0 = dim_size(i);
912 const int64 dim1 = shape.dim_size(i);
913 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
914 return errors::InvalidArgument(
915 "PartialTensorShape: Incompatible shapes during merge: ",
916 DebugString(), " vs. ", shape.DebugString());
917 }
918 s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
919 if (!s.ok()) {
920 return s;
921 }
922 }
923 return Status::OK();
924 }
925
AsTensorShape(TensorShape * shape) const926 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
927 if (IsFullyDefined()) {
928 const TensorShapeRep* rep = this;
929 *shape = *static_cast<const TensorShape*>(rep);
930 return true;
931 }
932 return false;
933 }
934
IsIdenticalTo(const PartialTensorShape & shape) const935 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
936 if (unknown_rank() || shape.unknown_rank()) {
937 return unknown_rank() == shape.unknown_rank();
938 }
939 if (dims() != shape.dims()) return false;
940 for (int i = 0; i < dims(); i++) {
941 if (dim_size(i) != shape.dim_size(i)) return false;
942 }
943 return true;
944 }
945
IsCompatibleWith(const PartialTensorShape & shape) const946 bool PartialTensorShape::IsCompatibleWith(
947 const PartialTensorShape& shape) const {
948 if (unknown_rank() || shape.unknown_rank()) return true;
949 if (dims() != shape.dims()) return false;
950 for (int i = 0; i < dims(); i++) {
951 const int64 dim0 = dim_size(i);
952 const int64 dim1 = shape.dim_size(i);
953 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
954 }
955 return true;
956 }
957
PartialShapeListString(const gtl::ArraySlice<PartialTensorShape> & shapes)958 string PartialTensorShapeUtils::PartialShapeListString(
959 const gtl::ArraySlice<PartialTensorShape>& shapes) {
960 string result = "[";
961 bool first = true;
962 for (const PartialTensorShape& shape : shapes) {
963 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
964 first = false;
965 }
966 strings::StrAppend(&result, "]");
967 return result;
968 }
969
AreCompatible(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)970 bool PartialTensorShapeUtils::AreCompatible(
971 const gtl::ArraySlice<PartialTensorShape>& shapes0,
972 const gtl::ArraySlice<PartialTensorShape>& shapes1) {
973 if (shapes0.size() == shapes1.size()) {
974 for (size_t i = 0; i < shapes0.size(); ++i) {
975 if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
976 return false;
977 }
978 }
979 return true;
980 } else {
981 return false;
982 }
983 }
984
AreIdentical(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)985 bool PartialTensorShapeUtils::AreIdentical(
986 const gtl::ArraySlice<PartialTensorShape>& shapes0,
987 const gtl::ArraySlice<PartialTensorShape>& shapes1) {
988 if (shapes0.size() == shapes1.size()) {
989 for (size_t i = 0; i < shapes0.size(); ++i) {
990 if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
991 return false;
992 }
993 }
994 return true;
995 } else {
996 return false;
997 }
998 }
999
NumElements(gtl::ArraySlice<int64> shape,int64 * num_elements)1000 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
1001 int64* num_elements) {
1002 int64 n = 1;
1003 for (auto dim : shape) {
1004 n = MultiplyWithoutOverflow(n, dim);
1005 if (n < 0) {
1006 return errors::InvalidArgument("Can't compute total size of shape [",
1007 absl::StrJoin(shape, ","),
1008 "]; product would overflow int64");
1009 }
1010 }
1011 *num_elements = n;
1012 return Status::OK();
1013 }
1014
1015 template class TensorShapeBase<TensorShape>;
1016 template class TensorShapeBase<PartialTensorShape>;
1017
1018 } // namespace tensorflow
1019