1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/tensor_shape.h"
17
18 #include "tensorflow/core/framework/bounds_check.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/util/overflow.h"
27
28 namespace tensorflow {
29
30 // TensorShape and PartialTensorShape should have no fields beyond
31 // TensorShapeRep. In particular, their sizes should be the same.
32 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
33 "TensorShape must have no fields beyond TensorShapeRep");
34 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
35 "PartialTensorShape must have no fields beyond TensorShapeRep");
36
37 template <class Shape>
AppendTo(const TensorShapeBase<Shape> & s,gtl::InlinedVector<int64,8> * vals)38 static void AppendTo(const TensorShapeBase<Shape>& s,
39 gtl::InlinedVector<int64, 8>* vals) {
40 for (auto dim : s) {
41 vals->push_back(dim.size);
42 }
43 }
44
CheckDimsEqual(int NDIMS) const45 void TensorShape::CheckDimsEqual(int NDIMS) const {
46 CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
47 << " from a tensor of " << dims() << " dimensions";
48 }
49
CheckDimsAtLeast(int NDIMS) const50 void TensorShape::CheckDimsAtLeast(int NDIMS) const {
51 CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
52 << " dimensions from a tensor of " << dims()
53 << " dimensions";
54 }
55
56 // TODO(slebedev): Consider merging IsValid implementations.
57 template <class Shape>
IsValid()58 bool TensorShapeBase<Shape>::IsValid() {
59 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
60 // unknown_shape() set, and it seems hard to remove this without backwards
61 // compatibility issues.
62 if (kIsPartial && unknown_rank()) return dims() == 0;
63 int64_t num_elements = 1;
64 if (dims() > MaxDimensions()) return false;
65 for (auto d : dim_sizes()) {
66 if (d < (kIsPartial ? -1 : 0)) return false;
67 if (d == -1) {
68 num_elements = -1;
69 } else if (!kIsPartial || num_elements >= 0) {
70 num_elements = MultiplyWithoutOverflow(num_elements, d);
71 if (num_elements < 0) return false;
72 }
73 }
74 return true;
75 }
76
77 template <class Shape>
IsValid(const TensorShapeProto & proto)78 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
79 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
80 // unknown_shape() set, and it seems hard to remove this without backwards
81 // compatibility issues.
82 if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
83 int64_t num_elements = 1;
84 if (proto.dim().size() > MaxDimensions()) return false;
85 for (const auto& d : proto.dim()) {
86 if (d.size() < (kIsPartial ? -1 : 0)) return false;
87 if (d.size() == -1) {
88 num_elements = -1;
89 } else if (!kIsPartial || num_elements >= 0) {
90 num_elements = MultiplyWithoutOverflow(num_elements, d.size());
91 if (num_elements < 0) return false;
92 }
93 }
94 return true;
95 }
96
97 template <class Shape>
IsValidShape(const TensorShapeProto & proto)98 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
99 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
100 // unknown_shape() set, and it seems hard to remove this without backwards
101 // compatibility issues.
102 if (kIsPartial && proto.unknown_rank()) {
103 if (proto.dim_size() > 0) {
104 return errors::InvalidArgument(
105 "An unknown shape must not have any dimensions set.");
106 }
107 return Status::OK();
108 }
109 int64_t num_elements = 1;
110 if (proto.dim().size() > MaxDimensions()) {
111 return errors::InvalidArgument("Shape ", DebugString(proto),
112 " has too many dimensions");
113 }
114 for (const auto& d : proto.dim()) {
115 if (d.size() < (kIsPartial ? -1 : 0)) {
116 if (kIsPartial) {
117 return errors::InvalidArgument(
118 "Shape ", DebugString(proto),
119 " has dimensions with values below -1 (where -1 means unknown)");
120 } else {
121 return errors::InvalidArgument("Shape ", DebugString(proto),
122 " is not fully defined");
123 }
124 }
125 if (d.size() == -1) {
126 num_elements = -1;
127 } else if (!kIsPartial || num_elements >= 0) {
128 num_elements = MultiplyWithoutOverflow(num_elements, d.size());
129 if (num_elements < 0) {
130 return errors::InvalidArgument(
131 "Shape ", DebugString(proto),
132 " is too large (more than 2**63 - 1 entries)");
133 }
134 }
135 }
136 return Status::OK();
137 }
138
139 template <class Shape>
TensorShapeBase(const TensorShapeProto & proto)140 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
141 set_tag(REP16);
142 set_data_type(DT_INVALID);
143 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
144 // unknown_shape() set, and it seems hard to remove this without backwards
145 // compatibility issues.
146 if (kIsPartial && proto.unknown_rank()) {
147 set_ndims_byte(kUnknownRank);
148 set_num_elements(-1);
149 } else {
150 set_ndims_byte(0);
151 set_num_elements(1);
152 for (const auto& d : proto.dim()) {
153 AddDim(d.size());
154 }
155 }
156 }
157
158 template <class Shape>
BuildTensorShapeBase(const TensorShapeProto & proto,TensorShapeBase * out)159 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
160 const TensorShapeProto& proto, TensorShapeBase* out) {
161 out->set_tag(REP16);
162 out->set_data_type(DT_INVALID);
163 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
164 // unknown_shape() set, and it seems hard to remove this without backwards
165 // compatibility issues.
166 if (kIsPartial && proto.unknown_rank()) {
167 out->set_ndims_byte(kUnknownRank);
168 out->set_num_elements(-1);
169 } else {
170 out->set_ndims_byte(0);
171 out->set_num_elements(1);
172 Status s = Status::OK();
173 for (const auto& d : proto.dim()) {
174 s = out->AddDimWithStatus(d.size());
175 if (!s.ok()) {
176 return s;
177 }
178 }
179 }
180 return Status::OK();
181 }
182
183 template <class Shape>
TensorShapeBase(gtl::ArraySlice<int64> dim_sizes)184 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
185 set_tag(REP16);
186 set_data_type(DT_INVALID);
187 TF_CHECK_OK(InitDims(dim_sizes));
188 }
189
190 template <class Shape>
BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,TensorShapeBase * out)191 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
192 gtl::ArraySlice<int64> dim_sizes, TensorShapeBase* out) {
193 out->set_tag(REP16);
194 out->set_data_type(DT_INVALID);
195 return out->InitDims(dim_sizes);
196 }
197
198 // Returns true iff partial is true and val is < 0.
199 // REQUIRES: val < kMaxRep16
200 // REQUIRES: partial || val >= 0
Set16(bool partial,uint16 * dst,int dim,int64_t val)201 static inline bool Set16(bool partial, uint16* dst, int dim, int64_t val) {
202 if (partial) {
203 if (val < 0) {
204 dst[dim] = std::numeric_limits<uint16>::max();
205 return true;
206 }
207 }
208 dst[dim] = val;
209 return false;
210 }
211
212 template <class Shape>
InitDims(gtl::ArraySlice<int64> dim_sizes)213 Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
214 DCHECK_EQ(tag(), REP16);
215
216 // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
217 // below cannot overflow.
218 static const int64_t kMaxSmall = 0xd744;
219 static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
220 "bad overflow check");
221 bool large_size = false;
222 for (auto s : dim_sizes) {
223 if (s > kMaxSmall) {
224 large_size = true;
225 break;
226 }
227 }
228
229 if (!kIsPartial && !large_size) {
230 for (auto s : dim_sizes) {
231 if (TF_PREDICT_FALSE(s < 0)) {
232 return errors::Internal(
233 "Expected shape dimensions to be non-negative, got ", s);
234 }
235 }
236 }
237
238 if (!large_size) {
239 // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
240 uint16* dst = as16()->dims_;
241 switch (dim_sizes.size()) {
242 case 1: {
243 set_ndims_byte(1);
244 const int64_t size = dim_sizes[0];
245 const bool neg = Set16(kIsPartial, dst, 0, size);
246 set_num_elements(neg ? -1 : size);
247 return Status::OK();
248 }
249 case 2: {
250 set_ndims_byte(2);
251 const int64_t size0 = dim_sizes[0];
252 const int64_t size1 = dim_sizes[1];
253 bool neg = Set16(kIsPartial, dst, 0, size0);
254 neg |= Set16(kIsPartial, dst, 1, size1);
255 set_num_elements(neg ? -1 : (size0 * size1));
256 return Status::OK();
257 }
258 case 3: {
259 set_ndims_byte(3);
260 const int64_t size0 = dim_sizes[0];
261 const int64_t size1 = dim_sizes[1];
262 const int64_t size2 = dim_sizes[2];
263 bool neg = Set16(kIsPartial, dst, 0, size0);
264 neg |= Set16(kIsPartial, dst, 1, size1);
265 neg |= Set16(kIsPartial, dst, 2, size2);
266 set_num_elements(neg ? -1 : (size0 * size1 * size2));
267 return Status::OK();
268 }
269 case 4: {
270 set_ndims_byte(4);
271 const int64_t size0 = dim_sizes[0];
272 const int64_t size1 = dim_sizes[1];
273 const int64_t size2 = dim_sizes[2];
274 const int64_t size3 = dim_sizes[3];
275 bool neg = Set16(kIsPartial, dst, 0, size0);
276 neg |= Set16(kIsPartial, dst, 1, size1);
277 neg |= Set16(kIsPartial, dst, 2, size2);
278 neg |= Set16(kIsPartial, dst, 3, size3);
279 set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
280 return Status::OK();
281 }
282 }
283 }
284
285 set_ndims_byte(0);
286 set_num_elements(1);
287 Status status = Status::OK();
288 for (int64_t s : dim_sizes) {
289 status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
290 if (!status.ok()) {
291 return status;
292 }
293 }
294
295 return status;
296 }
297
298 template <class Shape>
TensorShapeBase()299 TensorShapeBase<Shape>::TensorShapeBase() {
300 set_tag(REP16);
301 set_data_type(DT_INVALID);
302 if (kIsPartial) {
303 set_ndims_byte(kUnknownRank);
304 set_num_elements(-1);
305 } else {
306 set_ndims_byte(0);
307 set_num_elements(1);
308 }
309 }
310
DestructorOutOfLine()311 void TensorShapeRep::DestructorOutOfLine() {
312 DCHECK(tag() == REP_OUT_OF_LINE);
313 delete as64()->dims_;
314 }
315
SlowCopyFrom(const TensorShapeRep & b)316 void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
317 if (b.tag() != REP_OUT_OF_LINE) {
318 if (tag() == REP_OUT_OF_LINE) {
319 delete as64()->dims_;
320 }
321 memcpy(buf(), b.buf(), sizeof(u_.buf));
322 // memcpy above implicitly also does:
323 // set_tag(b.tag());
324 // set_ndims_byte(b.ndims_byte());
325 // set_data_type(b.data_type());
326 } else {
327 set_ndims_byte(b.ndims_byte());
328 set_data_type(b.data_type());
329 if (tag() == REP_OUT_OF_LINE) {
330 // vector already allocated
331 *(as64()->dims_) = *(b.as64()->dims_);
332 } else {
333 set_tag(REP_OUT_OF_LINE);
334 as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
335 }
336 }
337 }
338
339 template <class Shape>
dim_size(int d) const340 int64 TensorShapeBase<Shape>::dim_size(int d) const {
341 if (unknown_rank()) return -1;
342 DCHECK_GE(d, 0);
343 DCHECK_LT(d, dims());
344 if (tag() == REP16) {
345 uint16 dim = as16()->dims_[d];
346 if (kIsPartial && dim == kUnknownRep16) return -1;
347 return dim;
348 } else if (tag() == REP32) {
349 uint32 dim = as32()->dims_[d];
350 if (kIsPartial && dim == kUnknownRep32) return -1;
351 return dim;
352 } else {
353 return (*as64()->dims_)[d];
354 }
355 }
356
Clear()357 void TensorShapeRep::Clear() {
358 ClearAllButDataType();
359 set_data_type(DT_INVALID);
360 }
361
ClearAllButDataType()362 void TensorShapeRep::ClearAllButDataType() {
363 if (tag() == REP_OUT_OF_LINE) {
364 delete as64()->dims_;
365 }
366 set_tag(REP16);
367 set_ndims_byte(0);
368 // Leaves data_type alone
369 set_num_elements(1);
370 }
371
372 template <class Shape>
RecomputeNumElements()373 Status TensorShapeBase<Shape>::RecomputeNumElements() {
374 if (unknown_rank()) {
375 set_num_elements(-1);
376 return Status::OK();
377 }
378 int64_t n = 1;
379 for (auto dim : *this) {
380 if (kIsPartial && dim.size < 0) {
381 n = -1;
382 break;
383 }
384 n = MultiplyWithoutOverflow(n, dim.size);
385 if (TF_PREDICT_FALSE(n < 0)) {
386 return errors::InvalidArgument(
387 "Shape ", this->DebugString(),
388 " results in overflow when computing number of elements");
389 }
390 }
391 set_num_elements(n);
392 return Status::OK();
393 }
394
395 template <class Shape>
AddDim(int64_t size)396 void TensorShapeBase<Shape>::AddDim(int64_t size) {
397 if (!kIsPartial) CHECK_GE(size, 0);
398 if (unknown_rank()) return;
399 CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
400 int64_t new_num_elements;
401 if (kIsPartial && (num_elements() < 0 || size < 0)) {
402 new_num_elements = -1;
403 } else {
404 new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
405 CHECK_LE(0, new_num_elements);
406 }
407 UnsafeAddDim(size, new_num_elements);
408 }
409
410 template <class Shape>
AddDimWithStatus(int64_t size)411 Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
412 if (!kIsPartial) {
413 if (TF_PREDICT_FALSE(size < 0)) {
414 return errors::Internal("Expected a non-negative size, got ", size);
415 }
416 }
417
418 if (unknown_rank()) {
419 return Status::OK();
420 }
421
422 if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
423 return errors::Internal("Too many dimensions in tensor");
424 }
425
426 int64_t new_num_elements;
427 if (kIsPartial && (num_elements() < 0 || size < 0)) {
428 new_num_elements = -1;
429 } else {
430 new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
431 if (TF_PREDICT_FALSE(new_num_elements < 0)) {
432 return errors::Internal("Encountered overflow when multiplying ",
433 num_elements(), " with ", size,
434 ", result: ", new_num_elements);
435 }
436 }
437
438 UnsafeAddDim(size, new_num_elements);
439 return Status::OK();
440 }
441
442 template <class Shape>
UnsafeAddDim(int64_t size,int64_t new_num_elements)443 void TensorShapeBase<Shape>::UnsafeAddDim(int64_t size,
444 int64_t new_num_elements) {
445 const int nd = ndims_byte();
446 if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
447 as16()->dims_[nd] =
448 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
449 } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
450 as32()->dims_[nd] =
451 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
452 } else if (tag() == REP_OUT_OF_LINE) {
453 as64()->dims_->push_back(size);
454 } else {
455 // Need to change representation
456 gtl::InlinedVector<int64, 8> vals;
457 AppendTo(*this, &vals);
458 vals.push_back(size);
459 // We know we can't be REP16. See if we have a small enough
460 // number of dimensions and each dimension's size is small enough
461 // to allow REP32.
462 bool can_be_rep32 = (vals.size() <= 3);
463 if (can_be_rep32) {
464 for (size_t i = 0; i < vals.size(); i++) {
465 if (vals[i] >= kMaxRep32) {
466 can_be_rep32 = false;
467 break;
468 }
469 }
470 }
471 if (can_be_rep32) {
472 set_tag(REP32);
473 for (size_t d = 0; d < vals.size(); d++) {
474 as32()->dims_[d] = kIsPartial && vals[d] < 0
475 ? kUnknownRep32
476 : static_cast<uint32>(vals[d]);
477 }
478 } else {
479 set_tag(REP_OUT_OF_LINE);
480 as64()->dims_ =
481 new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
482 }
483 }
484 set_ndims_byte(nd + 1);
485 set_num_elements(new_num_elements);
486 }
487
488 template <class Shape>
AppendShape(const TensorShapeBase & shape)489 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
490 for (auto d : shape) AddDim(d.size);
491 }
492
493 template <class Shape>
AppendShapeWithStatus(const TensorShapeBase & shape)494 Status TensorShapeBase<Shape>::AppendShapeWithStatus(
495 const TensorShapeBase& shape) {
496 Status s = Status::OK();
497 for (auto d : shape) {
498 s.Update(AddDimWithStatus(d.size));
499 if (!s.ok()) {
500 return s;
501 }
502 }
503 return s;
504 }
505
506 template <class Shape>
InsertDim(int d,int64_t size)507 void TensorShapeBase<Shape>::InsertDim(int d, int64_t size) {
508 CHECK_GE(d, 0);
509 CHECK_LE(d, dims());
510 if (!kIsPartial) CHECK_GE(size, 0);
511 CHECK_LT(dims(), MaxDimensions());
512 gtl::InlinedVector<int64, 8> vals;
513 AppendTo(*this, &vals);
514 vals.insert(vals.begin() + d, size);
515 ClearAllButDataType();
516 for (auto dval : vals) {
517 AddDim(dval);
518 }
519 }
520
521 template <class Shape>
InsertDimWithStatus(int d,int64_t size)522 Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64_t size) {
523 if (!kIsPartial) {
524 if (TF_PREDICT_FALSE(size < 0)) {
525 return errors::Internal("Expected a non-negative size, got ", size);
526 }
527 }
528
529 if (TF_PREDICT_FALSE(d < 0)) {
530 return errors::Internal("The insertion index must be non-negative, got ",
531 d);
532 }
533 if (TF_PREDICT_FALSE(d > dims())) {
534 return errors::Internal("The insertion index must be at most ", dims(),
535 " got ", d);
536 }
537 if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
538 return errors::Internal("Shape has ", dims(),
539 " dimensions which is the maximum allowed");
540 }
541
542 gtl::InlinedVector<int64, 8> vals;
543 AppendTo(*this, &vals);
544 vals.insert(vals.begin() + d, size);
545 ClearAllButDataType();
546
547 Status s = Status::OK();
548 for (auto dval : vals) {
549 s.Update(AddDimWithStatus(dval));
550 if (!s.ok()) {
551 return s;
552 }
553 }
554 return s;
555 }
556
557 template <class Shape>
dim_sizes() const558 gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
559 gtl::InlinedVector<int64, 4> result;
560 for (auto dim : *this) {
561 result.push_back(dim.size);
562 }
563 return result;
564 }
565
566 template <class Shape>
set_dim(int d,int64_t size)567 void TensorShapeBase<Shape>::set_dim(int d, int64_t size) {
568 CHECK_GE(d, 0);
569 CHECK_LT(d, dims());
570 if (!kIsPartial) {
571 CHECK_GE(size, 0);
572 }
573 if (tag() == REP16 && size < kMaxRep16) {
574 as16()->dims_[d] =
575 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
576 } else if (tag() == REP32 && size < kMaxRep32) {
577 as32()->dims_[d] =
578 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
579 } else if (tag() == REP_OUT_OF_LINE) {
580 (*as64()->dims_)[d] = size;
581 } else {
582 // Must upgrade
583 gtl::InlinedVector<int64, 8> vals;
584 AppendTo(*this, &vals);
585 vals[d] = size;
586 ClearAllButDataType();
587 for (auto dval : vals) {
588 AddDim(dval);
589 }
590 }
591 TF_CHECK_OK(RecomputeNumElements());
592 }
593
594 template <class Shape>
SetDimWithStatus(int d,int64_t size)595 Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) {
596 if (TF_PREDICT_FALSE(d < 0)) {
597 return errors::Internal("Index must be non-negative, got ", d);
598 }
599 if (TF_PREDICT_FALSE(d >= dims())) {
600 return errors::Internal("Index must be less than ", dims(), ", got ", d);
601 }
602 if (TF_PREDICT_FALSE(!kIsPartial && size < 0)) {
603 return errors::Internal("Expected a non-negative size, got ", size);
604 }
605
606 if (tag() == REP16 && size < kMaxRep16) {
607 as16()->dims_[d] =
608 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
609 } else if (tag() == REP32 && size < kMaxRep32) {
610 as32()->dims_[d] =
611 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
612 } else if (tag() == REP_OUT_OF_LINE) {
613 (*as64()->dims_)[d] = size;
614 } else {
615 // Must upgrade
616 gtl::InlinedVector<int64, 8> vals;
617 AppendTo(*this, &vals);
618 vals[d] = size;
619 ClearAllButDataType();
620
621 Status s = Status::OK();
622 for (auto dval : vals) {
623 s.Update(AddDimWithStatus(dval));
624 if (!s.ok()) {
625 return s;
626 }
627 }
628 }
629
630 return RecomputeNumElements();
631 }
632
633 template <class Shape>
RemoveDimRange(int begin,int end)634 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
635 if (unknown_rank()) return;
636 begin = begin < 0 ? dims() + begin + 1 : begin;
637 end = end < 0 ? dims() + end + 1 : end;
638 CHECK_GE(begin, 0);
639 CHECK_LE(begin, dims());
640 CHECK_GE(end, 0);
641 CHECK_LE(end, dims());
642 if (begin >= end) return;
643 gtl::InlinedVector<int64, 8> vals;
644 AppendTo(*this, &vals);
645 vals.erase(vals.begin() + begin, vals.begin() + end);
646 ClearAllButDataType();
647 for (auto dval : vals) {
648 AddDim(dval);
649 }
650 TF_CHECK_OK(RecomputeNumElements());
651 }
652
653 template <class Shape>
RemoveDimRangeWithStatus(int begin,int end)654 Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
655 if (unknown_rank()) {
656 return Status::OK();
657 }
658
659 begin = begin < 0 ? dims() + begin + 1 : begin;
660 end = end < 0 ? dims() + end + 1 : end;
661
662 if (TF_PREDICT_FALSE(begin < 0)) {
663 return errors::Internal("Start index must be non-negative, got ", begin);
664 }
665 if (TF_PREDICT_FALSE(begin > dims())) {
666 return errors::Internal("Start index must be less than ", dims(), ", got ",
667 begin);
668 }
669 if (TF_PREDICT_FALSE(end < 0)) {
670 return errors::Internal("End index must be non-negative, got ", end);
671 }
672 if (TF_PREDICT_FALSE(end > dims())) {
673 return errors::Internal("End index must be less than ", dims(), ", got ",
674 end);
675 }
676
677 if (begin >= end) {
678 return Status::OK();
679 }
680
681 gtl::InlinedVector<int64, 8> vals;
682 AppendTo(*this, &vals);
683 vals.erase(vals.begin() + begin, vals.begin() + end);
684 ClearAllButDataType();
685
686 Status s = Status::OK();
687 for (auto dval : vals) {
688 s.Update(AddDimWithStatus(dval));
689 if (!s.ok()) {
690 return s;
691 }
692 }
693
694 return RecomputeNumElements();
695 }
696
IsSameSize(const TensorShape & b) const697 bool TensorShape::IsSameSize(const TensorShape& b) const {
698 if (b.dims() != dims()) return false;
699 for (int d = 0; d < dims(); d++) {
700 if (dim_size(d) != b.dim_size(d)) return false;
701 }
702 return true;
703 }
704
705 template <class Shape>
AsProto(TensorShapeProto * proto) const706 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
707 proto->Clear();
708 if (unknown_rank()) {
709 proto->set_unknown_rank(true);
710 } else {
711 for (int i = 0; i < dims(); i++) {
712 proto->add_dim()->set_size(dim_size(i));
713 }
714 }
715 }
716
717 template <class Shape>
begin() const718 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
719 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
720 }
721
722 template <class Shape>
end() const723 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
724 const int max_dim = unknown_rank() ? -1 : dims();
725 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), max_dim);
726 }
727
DebugString() const728 string TensorShapeRep::DebugString() const {
729 const auto& shape = *static_cast<const PartialTensorShape*>(this);
730 if (shape.unknown_rank()) return "<unknown>";
731 string s = "[";
732 for (int i = 0; i < shape.dims(); i++) {
733 if (i > 0) strings::StrAppend(&s, ",");
734 int64_t dim = shape.dim_size(i);
735 if (dim < 0) {
736 strings::StrAppend(&s, "?");
737 } else {
738 strings::StrAppend(&s, dim);
739 }
740 }
741 strings::StrAppend(&s, "]");
742 return s;
743 }
744
DebugString(const TensorShapeProto & proto)745 string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
746 string s;
747 if (proto.unknown_rank()) {
748 strings::StrAppend(&s, "<unknown>");
749 if (proto.dim_size() == 0) return s;
750 }
751 strings::StrAppend(&s, "[");
752 bool first = true;
753 for (const auto& d : proto.dim()) {
754 if (!first) strings::StrAppend(&s, ",");
755 if (d.size() == -1) {
756 strings::StrAppend(&s, "?");
757 } else {
758 strings::StrAppend(&s, d.size());
759 }
760 first = false;
761 }
762 strings::StrAppend(&s, "]");
763 return s;
764 }
765
StartsWith(const TensorShape & shape,const TensorShape & prefix)766 bool TensorShapeUtils::StartsWith(const TensorShape& shape,
767 const TensorShape& prefix) {
768 if (shape.dims() < prefix.dims()) return false;
769 for (int i = 0; i < prefix.dims(); ++i) {
770 if (shape.dim_size(i) != prefix.dim_size(i)) return false;
771 }
772 return true;
773 }
774
EndsWith(const TensorShape & shape,const TensorShape & suffix)775 bool TensorShapeUtils::EndsWith(const TensorShape& shape,
776 const TensorShape& suffix) {
777 const int suffix_size = suffix.dims();
778 if (shape.dims() < suffix_size) return false;
779 for (int i = 0; i < suffix_size; ++i) {
780 if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
781 return false;
782 }
783 }
784 return true;
785 }
786
787 template <typename T, class Shape>
MakeShapeHelper(const T * dims,int64_t n,Shape * out)788 Status MakeShapeHelper(const T* dims, int64_t n, Shape* out) {
789 out->Clear();
790 if (n > TensorShape::MaxDimensions()) {
791 return errors::InvalidArgument("Too many dimensions");
792 }
793 if (n < 0) {
794 return errors::InvalidArgument("Negative number of dimensions ", n);
795 }
796 for (int64_t i = 0; i < n; ++i) {
797 T dim = internal::SubtleMustCopy(dims[i]);
798 int64_t new_num_elements;
799 if (dim < 0) {
800 if (!out->kIsPartial) {
801 return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
802 }
803 if (dim < -1) {
804 return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
805 }
806 dim = -1;
807 new_num_elements = -1;
808 } else if (out->num_elements() < 0) {
809 new_num_elements = -1;
810 } else {
811 new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
812 if (TF_PREDICT_FALSE(new_num_elements < 0)) {
813 TensorShapeProto proto;
814 for (int64_t j = 0; j < n; ++j) {
815 proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
816 }
817 return errors::InvalidArgument(
818 "Shape ", TensorShape::DebugString(proto),
819 " would have more than 2**63 - 1 elements");
820 }
821 }
822 out->UnsafeAddDim(dim, new_num_elements);
823 }
824 return Status::OK();
825 }
826
827 #define MAKE_SHAPE(T, Shape) \
828 Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) { \
829 return MakeShapeHelper(dims, n, out); \
830 } \
831 Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
832 return MakeShapeHelper(shape.data(), shape.size(), out); \
833 }
MAKE_SHAPE(int32,TensorShape)834 MAKE_SHAPE(int32, TensorShape)
835 MAKE_SHAPE(int64, TensorShape)
836 MAKE_SHAPE(int32, PartialTensorShape)
837 MAKE_SHAPE(int64, PartialTensorShape)
838 #undef MAKE_SHAPE
839
840 string TensorShapeUtils::ShapeListString(
841 const gtl::ArraySlice<TensorShape>& shapes) {
842 string result = "[";
843 bool first = true;
844 for (const TensorShape& shape : shapes) {
845 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
846 first = false;
847 }
848 strings::StrAppend(&result, "]");
849 return result;
850 }
851
Concatenate(int64_t size) const852 PartialTensorShape PartialTensorShape::Concatenate(int64_t size) const {
853 PartialTensorShape out = *this;
854 out.AddDim(size);
855 return out;
856 }
857
ConcatenateWithStatus(int64_t size,PartialTensorShape * out) const858 Status PartialTensorShape::ConcatenateWithStatus(
859 int64_t size, PartialTensorShape* out) const {
860 out = const_cast<PartialTensorShape*>(this);
861 return out->AddDimWithStatus(size);
862 }
863
// Returns this shape followed by all dims of `shape`. The result has unknown
// rank if either input does. Dims are appended via AppendShape/AddDim, which
// CHECK-fail on invalid results.
PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return PartialTensorShape();
  PartialTensorShape result(*this);
  result.AppendShape(shape);
  return result;
}
873
ConcatenateWithStatus(const PartialTensorShape & shape,PartialTensorShape * out) const874 Status PartialTensorShape::ConcatenateWithStatus(
875 const PartialTensorShape& shape, PartialTensorShape* out) const {
876 if (unknown_rank() || shape.unknown_rank()) {
877 *out = PartialTensorShape();
878 return Status::OK();
879 }
880 out = const_cast<PartialTensorShape*>(this);
881 for (auto dim : shape) {
882 Status s = out->AddDimWithStatus(dim.size);
883 if (!s.ok()) return s;
884 }
885
886 return Status::OK();
887 }
888
MergeWith(const PartialTensorShape & shape,PartialTensorShape * result) const889 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
890 PartialTensorShape* result) const {
891 if (unknown_rank()) {
892 *result = shape;
893 return Status::OK();
894 }
895 if (shape.unknown_rank()) {
896 *result = *this;
897 return Status::OK();
898 }
899 const int dims_ = dims();
900 if (dims_ != shape.dims()) {
901 return errors::InvalidArgument(
902 "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
903 shape.dims());
904 }
905
906 if (result == this) {
907 return errors::Internal(
908 "PartialTensorShape::MergeWith: cannot merge shape with itself");
909 }
910
911 result->Clear();
912 Status s = Status::OK();
913 for (int i = 0; i < dims_; ++i) {
914 const int64_t dim0 = dim_size(i);
915 const int64_t dim1 = shape.dim_size(i);
916 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
917 return errors::InvalidArgument(
918 "PartialTensorShape: Incompatible shapes during merge: ",
919 DebugString(), " vs. ", shape.DebugString());
920 }
921 s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
922 if (!s.ok()) {
923 return s;
924 }
925 }
926 return Status::OK();
927 }
928
AsTensorShape(TensorShape * shape) const929 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
930 if (IsFullyDefined()) {
931 const TensorShapeRep* rep = this;
932 *shape = *static_cast<const TensorShape*>(rep);
933 return true;
934 }
935 return false;
936 }
937
IsIdenticalTo(const PartialTensorShape & shape) const938 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
939 if (unknown_rank() || shape.unknown_rank()) {
940 return unknown_rank() == shape.unknown_rank();
941 }
942 if (dims() != shape.dims()) return false;
943 for (int i = 0; i < dims(); i++) {
944 if (dim_size(i) != shape.dim_size(i)) return false;
945 }
946 return true;
947 }
948
IsCompatibleWith(const PartialTensorShape & shape) const949 bool PartialTensorShape::IsCompatibleWith(
950 const PartialTensorShape& shape) const {
951 if (unknown_rank() || shape.unknown_rank()) return true;
952 if (dims() != shape.dims()) return false;
953 for (int i = 0; i < dims(); i++) {
954 const int64_t dim0 = dim_size(i);
955 const int64_t dim1 = shape.dim_size(i);
956 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
957 }
958 return true;
959 }
960
PartialShapeListString(const gtl::ArraySlice<PartialTensorShape> & shapes)961 string PartialTensorShapeUtils::PartialShapeListString(
962 const gtl::ArraySlice<PartialTensorShape>& shapes) {
963 string result = "[";
964 bool first = true;
965 for (const PartialTensorShape& shape : shapes) {
966 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
967 first = false;
968 }
969 strings::StrAppend(&result, "]");
970 return result;
971 }
972
AreCompatible(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)973 bool PartialTensorShapeUtils::AreCompatible(
974 const gtl::ArraySlice<PartialTensorShape>& shapes0,
975 const gtl::ArraySlice<PartialTensorShape>& shapes1) {
976 if (shapes0.size() == shapes1.size()) {
977 for (size_t i = 0; i < shapes0.size(); ++i) {
978 if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
979 return false;
980 }
981 }
982 return true;
983 } else {
984 return false;
985 }
986 }
987
AreIdentical(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)988 bool PartialTensorShapeUtils::AreIdentical(
989 const gtl::ArraySlice<PartialTensorShape>& shapes0,
990 const gtl::ArraySlice<PartialTensorShape>& shapes1) {
991 if (shapes0.size() == shapes1.size()) {
992 for (size_t i = 0; i < shapes0.size(); ++i) {
993 if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
994 return false;
995 }
996 }
997 return true;
998 } else {
999 return false;
1000 }
1001 }
1002
NumElements(gtl::ArraySlice<int64> shape,int64 * num_elements)1003 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
1004 int64* num_elements) {
1005 int64_t n = 1;
1006 for (auto dim : shape) {
1007 n = MultiplyWithoutOverflow(n, dim);
1008 if (n < 0) {
1009 return errors::InvalidArgument("Can't compute total size of shape [",
1010 absl::StrJoin(shape, ","),
1011 "]; product would overflow int64");
1012 }
1013 }
1014 *num_elements = n;
1015 return Status::OK();
1016 }
1017
1018 template class TensorShapeBase<TensorShape>;
1019 template class TensorShapeBase<PartialTensorShape>;
1020
1021 } // namespace tensorflow
1022