1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_shape.h"
17 
18 #include "tensorflow/core/framework/bounds_check.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/util/overflow.h"
27 
28 namespace tensorflow {
29 
30 // TensorShape and PartialTensorShape should have no fields beyond
31 // TensorShapeRep.  In particular, their sizes should be the same.
32 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
33               "TensorShape must have no fields beyond TensorShapeRep");
34 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
35               "PartialTensorShape must have no fields beyond TensorShapeRep");
36 
37 template <class Shape>
38 static void AppendTo(const TensorShapeBase<Shape>& s,
39                      gtl::InlinedVector<int64, 8>* vals) {
40   for (auto dim : s) {
41     vals->push_back(dim.size);
42   }
43 }
44 
45 void TensorShape::CheckDimsEqual(int NDIMS) const {
46   CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
47                           << " from a tensor of " << dims() << " dimensions";
48 }
49 
50 void TensorShape::CheckDimsAtLeast(int NDIMS) const {
51   CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
52                           << " dimensions from a tensor of " << dims()
53                           << " dimensions";
54 }
55 
56 // TODO(slebedev): Consider merging IsValid implementations.
57 template <class Shape>
58 bool TensorShapeBase<Shape>::IsValid() {
59   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
60   // unknown_shape() set, and it seems hard to remove this without backwards
61   // compatibility issues.
62   if (kIsPartial && unknown_rank()) return dims() == 0;
63   int64_t num_elements = 1;
64   if (dims() > MaxDimensions()) return false;
65   for (auto d : dim_sizes()) {
66     if (d < (kIsPartial ? -1 : 0)) return false;
67     if (d == -1) {
68       num_elements = -1;
69     } else if (!kIsPartial || num_elements >= 0) {
70       num_elements = MultiplyWithoutOverflow(num_elements, d);
71       if (num_elements < 0) return false;
72     }
73   }
74   return true;
75 }
76 
77 template <class Shape>
78 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
79   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
80   // unknown_shape() set, and it seems hard to remove this without backwards
81   // compatibility issues.
82   if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
83   int64_t num_elements = 1;
84   if (proto.dim().size() > MaxDimensions()) return false;
85   for (const auto& d : proto.dim()) {
86     if (d.size() < (kIsPartial ? -1 : 0)) return false;
87     if (d.size() == -1) {
88       num_elements = -1;
89     } else if (!kIsPartial || num_elements >= 0) {
90       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
91       if (num_elements < 0) return false;
92     }
93   }
94   return true;
95 }
96 
97 template <class Shape>
98 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
99   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
100   // unknown_shape() set, and it seems hard to remove this without backwards
101   // compatibility issues.
102   if (kIsPartial && proto.unknown_rank()) {
103     if (proto.dim_size() > 0) {
104       return errors::InvalidArgument(
105           "An unknown shape must not have any dimensions set.");
106     }
107     return Status::OK();
108   }
109   int64_t num_elements = 1;
110   if (proto.dim().size() > MaxDimensions()) {
111     return errors::InvalidArgument("Shape ", DebugString(proto),
112                                    " has too many dimensions");
113   }
114   for (const auto& d : proto.dim()) {
115     if (d.size() < (kIsPartial ? -1 : 0)) {
116       if (kIsPartial) {
117         return errors::InvalidArgument(
118             "Shape ", DebugString(proto),
119             " has dimensions with values below -1 (where -1 means unknown)");
120       } else {
121         return errors::InvalidArgument("Shape ", DebugString(proto),
122                                        " is not fully defined");
123       }
124     }
125     if (d.size() == -1) {
126       num_elements = -1;
127     } else if (!kIsPartial || num_elements >= 0) {
128       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
129       if (num_elements < 0) {
130         return errors::InvalidArgument(
131             "Shape ", DebugString(proto),
132             " is too large (more than 2**63 - 1 entries)");
133       }
134     }
135   }
136   return Status::OK();
137 }
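// Editor's usage sketch (not part of the original file): how the proto
// validators above differ for fully defined vs. partial shapes, assuming a
// hand-built TensorShapeProto.
//
//   TensorShapeProto proto;
//   proto.add_dim()->set_size(128);
//   proto.add_dim()->set_size(-1);                        // "unknown" dimension
//   TensorShape::IsValid(proto);                          // false: -1 not allowed
//   PartialTensorShape::IsValid(proto);                   // true: -1 means unknown
//   Status s = PartialTensorShape::IsValidShape(proto);   // OK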
138 
139 template <class Shape>
140 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
141   set_tag(REP16);
142   set_data_type(DT_INVALID);
143   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
144   // unknown_shape() set, and it seems hard to remove this without backwards
145   // compatibility issues.
146   if (kIsPartial && proto.unknown_rank()) {
147     set_ndims_byte(kUnknownRank);
148     set_num_elements(-1);
149   } else {
150     set_ndims_byte(0);
151     set_num_elements(1);
152     for (const auto& d : proto.dim()) {
153       AddDim(d.size());
154     }
155   }
156 }
157 
158 template <class Shape>
159 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
160     const TensorShapeProto& proto, TensorShapeBase* out) {
161   out->set_tag(REP16);
162   out->set_data_type(DT_INVALID);
163   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
164   // unknown_shape() set, and it seems hard to remove this without backwards
165   // compatibility issues.
166   if (kIsPartial && proto.unknown_rank()) {
167     out->set_ndims_byte(kUnknownRank);
168     out->set_num_elements(-1);
169   } else {
170     out->set_ndims_byte(0);
171     out->set_num_elements(1);
172     Status s = Status::OK();
173     for (const auto& d : proto.dim()) {
174       s = out->AddDimWithStatus(d.size());
175       if (!s.ok()) {
176         return s;
177       }
178     }
179   }
180   return Status::OK();
181 }
182 
183 template <class Shape>
184 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
185   set_tag(REP16);
186   set_data_type(DT_INVALID);
187   TF_CHECK_OK(InitDims(dim_sizes));
188 }
189 
190 template <class Shape>
191 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
192     gtl::ArraySlice<int64> dim_sizes, TensorShapeBase* out) {
193   out->set_tag(REP16);
194   out->set_data_type(DT_INVALID);
195   return out->InitDims(dim_sizes);
196 }
197 
198 // Returns true iff partial is true and val is < 0.
199 // REQUIRES: val < kMaxRep16
200 // REQUIRES: partial || val >= 0
201 static inline bool Set16(bool partial, uint16* dst, int dim, int64_t val) {
202   if (partial) {
203     if (val < 0) {
204       dst[dim] = std::numeric_limits<uint16>::max();
205       return true;
206     }
207   }
208   dst[dim] = val;
209   return false;
210 }
211 
212 template <class Shape>
213 Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
214   DCHECK_EQ(tag(), REP16);
215 
216   // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
217   // below cannot overflow.
218   static const int64_t kMaxSmall = 0xd744;
219   static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
220                 "bad overflow check");
221   bool large_size = false;
222   for (auto s : dim_sizes) {
223     if (s > kMaxSmall) {
224       large_size = true;
225       break;
226     }
227   }
228 
229   if (!kIsPartial && !large_size) {
230     for (auto s : dim_sizes) {
231       if (TF_PREDICT_FALSE(s < 0)) {
232         return errors::Internal(
233             "Expected shape dimensions to be non-negative, got ", s);
234       }
235     }
236   }
237 
238   if (!large_size) {
239     // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
240     uint16* dst = as16()->dims_;
241     switch (dim_sizes.size()) {
242       case 1: {
243         set_ndims_byte(1);
244         const int64_t size = dim_sizes[0];
245         const bool neg = Set16(kIsPartial, dst, 0, size);
246         set_num_elements(neg ? -1 : size);
247         return Status::OK();
248       }
249       case 2: {
250         set_ndims_byte(2);
251         const int64_t size0 = dim_sizes[0];
252         const int64_t size1 = dim_sizes[1];
253         bool neg = Set16(kIsPartial, dst, 0, size0);
254         neg |= Set16(kIsPartial, dst, 1, size1);
255         set_num_elements(neg ? -1 : (size0 * size1));
256         return Status::OK();
257       }
258       case 3: {
259         set_ndims_byte(3);
260         const int64_t size0 = dim_sizes[0];
261         const int64_t size1 = dim_sizes[1];
262         const int64_t size2 = dim_sizes[2];
263         bool neg = Set16(kIsPartial, dst, 0, size0);
264         neg |= Set16(kIsPartial, dst, 1, size1);
265         neg |= Set16(kIsPartial, dst, 2, size2);
266         set_num_elements(neg ? -1 : (size0 * size1 * size2));
267         return Status::OK();
268       }
269       case 4: {
270         set_ndims_byte(4);
271         const int64_t size0 = dim_sizes[0];
272         const int64_t size1 = dim_sizes[1];
273         const int64_t size2 = dim_sizes[2];
274         const int64_t size3 = dim_sizes[3];
275         bool neg = Set16(kIsPartial, dst, 0, size0);
276         neg |= Set16(kIsPartial, dst, 1, size1);
277         neg |= Set16(kIsPartial, dst, 2, size2);
278         neg |= Set16(kIsPartial, dst, 3, size3);
279         set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
280         return Status::OK();
281       }
282     }
283   }
284 
285   set_ndims_byte(0);
286   set_num_elements(1);
287   Status status = Status::OK();
288   for (int64_t s : dim_sizes) {
289     status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
290     if (!status.ok()) {
291       return status;
292     }
293   }
294 
295   return status;
296 }
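// Editor's note (not part of the original file): the kMaxSmall bound used in
// InitDims above works out numerically as 0xd744 = 55108, and 55108^4 is
// about 9.2227e18, just under kint64max = 2^63 - 1 (about 9.2234e18), so the
// unchecked size0 * size1 * size2 * size3 products in the fast paths cannot
// overflow a signed 64-bit integer.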
297 
298 template <class Shape>
299 TensorShapeBase<Shape>::TensorShapeBase() {
300   set_tag(REP16);
301   set_data_type(DT_INVALID);
302   if (kIsPartial) {
303     set_ndims_byte(kUnknownRank);
304     set_num_elements(-1);
305   } else {
306     set_ndims_byte(0);
307     set_num_elements(1);
308   }
309 }
310 
311 void TensorShapeRep::DestructorOutOfLine() {
312   DCHECK(tag() == REP_OUT_OF_LINE);
313   delete as64()->dims_;
314 }
315 
316 void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
317   if (b.tag() != REP_OUT_OF_LINE) {
318     if (tag() == REP_OUT_OF_LINE) {
319       delete as64()->dims_;
320     }
321     memcpy(buf(), b.buf(), sizeof(u_.buf));
322     // memcpy above implicitly also does:
323     //   set_tag(b.tag());
324     //   set_ndims_byte(b.ndims_byte());
325     //   set_data_type(b.data_type());
326   } else {
327     set_ndims_byte(b.ndims_byte());
328     set_data_type(b.data_type());
329     if (tag() == REP_OUT_OF_LINE) {
330       // vector already allocated
331       *(as64()->dims_) = *(b.as64()->dims_);
332     } else {
333       set_tag(REP_OUT_OF_LINE);
334       as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
335     }
336   }
337 }
338 
339 template <class Shape>
340 int64 TensorShapeBase<Shape>::dim_size(int d) const {
341   if (unknown_rank()) return -1;
342   DCHECK_GE(d, 0);
343   DCHECK_LT(d, dims());
344   if (tag() == REP16) {
345     uint16 dim = as16()->dims_[d];
346     if (kIsPartial && dim == kUnknownRep16) return -1;
347     return dim;
348   } else if (tag() == REP32) {
349     uint32 dim = as32()->dims_[d];
350     if (kIsPartial && dim == kUnknownRep32) return -1;
351     return dim;
352   } else {
353     return (*as64()->dims_)[d];
354   }
355 }
356 
357 void TensorShapeRep::Clear() {
358   ClearAllButDataType();
359   set_data_type(DT_INVALID);
360 }
361 
362 void TensorShapeRep::ClearAllButDataType() {
363   if (tag() == REP_OUT_OF_LINE) {
364     delete as64()->dims_;
365   }
366   set_tag(REP16);
367   set_ndims_byte(0);
368   // Leaves data_type alone
369   set_num_elements(1);
370 }
371 
372 template <class Shape>
373 Status TensorShapeBase<Shape>::RecomputeNumElements() {
374   if (unknown_rank()) {
375     set_num_elements(-1);
376     return Status::OK();
377   }
378   int64_t n = 1;
379   for (auto dim : *this) {
380     if (kIsPartial && dim.size < 0) {
381       n = -1;
382       break;
383     }
384     n = MultiplyWithoutOverflow(n, dim.size);
385     if (TF_PREDICT_FALSE(n < 0)) {
386       return errors::InvalidArgument(
387           "Shape ", this->DebugString(),
388           " results in overflow when computing number of elements");
389     }
390   }
391   set_num_elements(n);
392   return Status::OK();
393 }
394 
395 template <class Shape>
396 void TensorShapeBase<Shape>::AddDim(int64_t size) {
397   if (!kIsPartial) CHECK_GE(size, 0);
398   if (unknown_rank()) return;
399   CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
400   int64_t new_num_elements;
401   if (kIsPartial && (num_elements() < 0 || size < 0)) {
402     new_num_elements = -1;
403   } else {
404     new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
405     CHECK_LE(0, new_num_elements);
406   }
407   UnsafeAddDim(size, new_num_elements);
408 }
409 
410 template <class Shape>
411 Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
412   if (!kIsPartial) {
413     if (TF_PREDICT_FALSE(size < 0)) {
414       return errors::Internal("Expected a non-negative size, got ", size);
415     }
416   }
417 
418   if (unknown_rank()) {
419     return Status::OK();
420   }
421 
422   if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
423     return errors::Internal("Too many dimensions in tensor");
424   }
425 
426   int64_t new_num_elements;
427   if (kIsPartial && (num_elements() < 0 || size < 0)) {
428     new_num_elements = -1;
429   } else {
430     new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
431     if (TF_PREDICT_FALSE(new_num_elements < 0)) {
432       return errors::Internal("Encountered overflow when multiplying ",
433                               num_elements(), " with ", size,
434                               ", result: ", new_num_elements);
435     }
436   }
437 
438   UnsafeAddDim(size, new_num_elements);
439   return Status::OK();
440 }
441 
442 template <class Shape>
443 void TensorShapeBase<Shape>::UnsafeAddDim(int64_t size,
444                                           int64_t new_num_elements) {
445   const int nd = ndims_byte();
446   if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
447     as16()->dims_[nd] =
448         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
449   } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
450     as32()->dims_[nd] =
451         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
452   } else if (tag() == REP_OUT_OF_LINE) {
453     as64()->dims_->push_back(size);
454   } else {
455     // Need to change representation
456     gtl::InlinedVector<int64, 8> vals;
457     AppendTo(*this, &vals);
458     vals.push_back(size);
459     // We know we can't be REP16.  See if we have a small enough
460     // number of dimensions and each dimension's size is small enough
461     // to allow REP32.
462     bool can_be_rep32 = (vals.size() <= 3);
463     if (can_be_rep32) {
464       for (size_t i = 0; i < vals.size(); i++) {
465         if (vals[i] >= kMaxRep32) {
466           can_be_rep32 = false;
467           break;
468         }
469       }
470     }
471     if (can_be_rep32) {
472       set_tag(REP32);
473       for (size_t d = 0; d < vals.size(); d++) {
474         as32()->dims_[d] = kIsPartial && vals[d] < 0
475                                ? kUnknownRep32
476                                : static_cast<uint32>(vals[d]);
477       }
478     } else {
479       set_tag(REP_OUT_OF_LINE);
480       as64()->dims_ =
481           new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
482     }
483   }
484   set_ndims_byte(nd + 1);
485   set_num_elements(new_num_elements);
486 }
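// Editor's sketch (not part of the original file): how the representation
// migrates as dimensions are added, assuming the kMaxRep16/kMaxRep32 limits
// declared in tensor_shape.h are roughly the uint16/uint32 ranges.
//
//   TensorShape s;             // REP16, rank 0
//   s.AddDim(100);             // fits in 16 bits, stays REP16
//   s.AddDim(1 << 20);         // exceeds kMaxRep16, upgraded to REP32
//   s.AddDim(1LL << 40);       // exceeds kMaxRep32, upgraded to REP_OUT_OF_LINE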
487 
488 template <class Shape>
489 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
490   for (auto d : shape) AddDim(d.size);
491 }
492 
493 template <class Shape>
494 Status TensorShapeBase<Shape>::AppendShapeWithStatus(
495     const TensorShapeBase& shape) {
496   Status s = Status::OK();
497   for (auto d : shape) {
498     s.Update(AddDimWithStatus(d.size));
499     if (!s.ok()) {
500       return s;
501     }
502   }
503   return s;
504 }
505 
506 template <class Shape>
507 void TensorShapeBase<Shape>::InsertDim(int d, int64_t size) {
508   CHECK_GE(d, 0);
509   CHECK_LE(d, dims());
510   if (!kIsPartial) CHECK_GE(size, 0);
511   CHECK_LT(dims(), MaxDimensions());
512   gtl::InlinedVector<int64, 8> vals;
513   AppendTo(*this, &vals);
514   vals.insert(vals.begin() + d, size);
515   ClearAllButDataType();
516   for (auto dval : vals) {
517     AddDim(dval);
518   }
519 }
520 
521 template <class Shape>
522 Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64_t size) {
523   if (!kIsPartial) {
524     if (TF_PREDICT_FALSE(size < 0)) {
525       return errors::Internal("Expected a non-negative size, got ", size);
526     }
527   }
528 
529   if (TF_PREDICT_FALSE(d < 0)) {
530     return errors::Internal("The insertion index must be non-negative, got ",
531                             d);
532   }
533   if (TF_PREDICT_FALSE(d > dims())) {
534     return errors::Internal("The insertion index must be at most ", dims(),
535                             ", got ", d);
536   }
537   if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
538     return errors::Internal("Shape has ", dims(),
539                             " dimensions which is the maximum allowed");
540   }
541 
542   gtl::InlinedVector<int64, 8> vals;
543   AppendTo(*this, &vals);
544   vals.insert(vals.begin() + d, size);
545   ClearAllButDataType();
546 
547   Status s = Status::OK();
548   for (auto dval : vals) {
549     s.Update(AddDimWithStatus(dval));
550     if (!s.ok()) {
551       return s;
552     }
553   }
554   return s;
555 }
556 
557 template <class Shape>
558 gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
559   gtl::InlinedVector<int64, 4> result;
560   for (auto dim : *this) {
561     result.push_back(dim.size);
562   }
563   return result;
564 }
565 
566 template <class Shape>
567 void TensorShapeBase<Shape>::set_dim(int d, int64_t size) {
568   CHECK_GE(d, 0);
569   CHECK_LT(d, dims());
570   if (!kIsPartial) {
571     CHECK_GE(size, 0);
572   }
573   if (tag() == REP16 && size < kMaxRep16) {
574     as16()->dims_[d] =
575         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
576   } else if (tag() == REP32 && size < kMaxRep32) {
577     as32()->dims_[d] =
578         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
579   } else if (tag() == REP_OUT_OF_LINE) {
580     (*as64()->dims_)[d] = size;
581   } else {
582     // Must upgrade
583     gtl::InlinedVector<int64, 8> vals;
584     AppendTo(*this, &vals);
585     vals[d] = size;
586     ClearAllButDataType();
587     for (auto dval : vals) {
588       AddDim(dval);
589     }
590   }
591   TF_CHECK_OK(RecomputeNumElements());
592 }
593 
594 template <class Shape>
595 Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) {
596   if (TF_PREDICT_FALSE(d < 0)) {
597     return errors::Internal("Index must be non-negative, got ", d);
598   }
599   if (TF_PREDICT_FALSE(d >= dims())) {
600     return errors::Internal("Index must be less than ", dims(), ", got ", d);
601   }
602   if (TF_PREDICT_FALSE(!kIsPartial && size < 0)) {
603     return errors::Internal("Expected a non-negative size, got ", size);
604   }
605 
606   if (tag() == REP16 && size < kMaxRep16) {
607     as16()->dims_[d] =
608         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
609   } else if (tag() == REP32 && size < kMaxRep32) {
610     as32()->dims_[d] =
611         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
612   } else if (tag() == REP_OUT_OF_LINE) {
613     (*as64()->dims_)[d] = size;
614   } else {
615     // Must upgrade
616     gtl::InlinedVector<int64, 8> vals;
617     AppendTo(*this, &vals);
618     vals[d] = size;
619     ClearAllButDataType();
620 
621     Status s = Status::OK();
622     for (auto dval : vals) {
623       s.Update(AddDimWithStatus(dval));
624       if (!s.ok()) {
625         return s;
626       }
627     }
628   }
629 
630   return RecomputeNumElements();
631 }
632 
633 template <class Shape>
634 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
635   if (unknown_rank()) return;
636   begin = begin < 0 ? dims() + begin + 1 : begin;
637   end = end < 0 ? dims() + end + 1 : end;
638   CHECK_GE(begin, 0);
639   CHECK_LE(begin, dims());
640   CHECK_GE(end, 0);
641   CHECK_LE(end, dims());
642   if (begin >= end) return;
643   gtl::InlinedVector<int64, 8> vals;
644   AppendTo(*this, &vals);
645   vals.erase(vals.begin() + begin, vals.begin() + end);
646   ClearAllButDataType();
647   for (auto dval : vals) {
648     AddDim(dval);
649   }
650   TF_CHECK_OK(RecomputeNumElements());
651 }
652 
653 template <class Shape>
654 Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
655   if (unknown_rank()) {
656     return Status::OK();
657   }
658 
659   begin = begin < 0 ? dims() + begin + 1 : begin;
660   end = end < 0 ? dims() + end + 1 : end;
661 
662   if (TF_PREDICT_FALSE(begin < 0)) {
663     return errors::Internal("Start index must be non-negative, got ", begin);
664   }
665   if (TF_PREDICT_FALSE(begin > dims())) {
666     return errors::Internal("Start index must be less than ", dims(), ", got ",
667                             begin);
668   }
669   if (TF_PREDICT_FALSE(end < 0)) {
670     return errors::Internal("End index must be non-negative, got ", end);
671   }
672   if (TF_PREDICT_FALSE(end > dims())) {
673     return errors::Internal("End index must be less than ", dims(), ", got ",
674                             end);
675   }
676 
677   if (begin >= end) {
678     return Status::OK();
679   }
680 
681   gtl::InlinedVector<int64, 8> vals;
682   AppendTo(*this, &vals);
683   vals.erase(vals.begin() + begin, vals.begin() + end);
684   ClearAllButDataType();
685 
686   Status s = Status::OK();
687   for (auto dval : vals) {
688     s.Update(AddDimWithStatus(dval));
689     if (!s.ok()) {
690       return s;
691     }
692   }
693 
694   return RecomputeNumElements();
695 }
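// Editor's usage sketch (not part of the original file): negative indices in
// RemoveDimRange count from the end (index + dims() + 1), Python-style.
//
//   TensorShape s({2, 3, 5, 7});
//   s.RemoveDimRange(1, 3);    // drops dims 1 and 2, leaving [2,7]
//   TensorShape t({2, 3, 5, 7});
//   t.RemoveDimRange(-2, -1);  // begin -> 3, end -> 4, leaving [2,3,5]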
696 
697 bool TensorShape::IsSameSize(const TensorShape& b) const {
698   if (b.dims() != dims()) return false;
699   for (int d = 0; d < dims(); d++) {
700     if (dim_size(d) != b.dim_size(d)) return false;
701   }
702   return true;
703 }
704 
705 template <class Shape>
706 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
707   proto->Clear();
708   if (unknown_rank()) {
709     proto->set_unknown_rank(true);
710   } else {
711     for (int i = 0; i < dims(); i++) {
712       proto->add_dim()->set_size(dim_size(i));
713     }
714   }
715 }
716 
717 template <class Shape>
718 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
719   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
720 }
721 
722 template <class Shape>
723 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
724   const int max_dim = unknown_rank() ? -1 : dims();
725   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), max_dim);
726 }
727 
728 string TensorShapeRep::DebugString() const {
729   const auto& shape = *static_cast<const PartialTensorShape*>(this);
730   if (shape.unknown_rank()) return "<unknown>";
731   string s = "[";
732   for (int i = 0; i < shape.dims(); i++) {
733     if (i > 0) strings::StrAppend(&s, ",");
734     int64_t dim = shape.dim_size(i);
735     if (dim < 0) {
736       strings::StrAppend(&s, "?");
737     } else {
738       strings::StrAppend(&s, dim);
739     }
740   }
741   strings::StrAppend(&s, "]");
742   return s;
743 }
744 
745 string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
746   string s;
747   if (proto.unknown_rank()) {
748     strings::StrAppend(&s, "<unknown>");
749     if (proto.dim_size() == 0) return s;
750   }
751   strings::StrAppend(&s, "[");
752   bool first = true;
753   for (const auto& d : proto.dim()) {
754     if (!first) strings::StrAppend(&s, ",");
755     if (d.size() == -1) {
756       strings::StrAppend(&s, "?");
757     } else {
758       strings::StrAppend(&s, d.size());
759     }
760     first = false;
761   }
762   strings::StrAppend(&s, "]");
763   return s;
764 }
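// Editor's sketch (not part of the original file): typical DebugString output
// for the two overloads above.
//
//   TensorShape({2, 3}).DebugString();          // "[2,3]"
//   PartialTensorShape({2, -1}).DebugString();  // "[2,?]"
//   PartialTensorShape().DebugString();         // "<unknown>"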
765 
766 bool TensorShapeUtils::StartsWith(const TensorShape& shape,
767                                   const TensorShape& prefix) {
768   if (shape.dims() < prefix.dims()) return false;
769   for (int i = 0; i < prefix.dims(); ++i) {
770     if (shape.dim_size(i) != prefix.dim_size(i)) return false;
771   }
772   return true;
773 }
774 
775 bool TensorShapeUtils::EndsWith(const TensorShape& shape,
776                                 const TensorShape& suffix) {
777   const int suffix_size = suffix.dims();
778   if (shape.dims() < suffix_size) return false;
779   for (int i = 0; i < suffix_size; ++i) {
780     if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
781       return false;
782     }
783   }
784   return true;
785 }
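// Editor's usage sketch (not part of the original file): StartsWith/EndsWith
// compare leading or trailing dimension sizes.
//
//   TensorShape shape({8, 28, 28, 3});
//   TensorShapeUtils::StartsWith(shape, TensorShape({8}));     // true
//   TensorShapeUtils::EndsWith(shape, TensorShape({28, 3}));   // true
//   TensorShapeUtils::EndsWith(shape, TensorShape({28, 28}));  // false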
786 
787 template <typename T, class Shape>
788 Status MakeShapeHelper(const T* dims, int64_t n, Shape* out) {
789   out->Clear();
790   if (n > TensorShape::MaxDimensions()) {
791     return errors::InvalidArgument("Too many dimensions");
792   }
793   if (n < 0) {
794     return errors::InvalidArgument("Negative number of dimensions ", n);
795   }
796   for (int64_t i = 0; i < n; ++i) {
797     T dim = internal::SubtleMustCopy(dims[i]);
798     int64_t new_num_elements;
799     if (dim < 0) {
800       if (!out->kIsPartial) {
801         return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
802       }
803       if (dim < -1) {
804         return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
805       }
806       dim = -1;
807       new_num_elements = -1;
808     } else if (out->num_elements() < 0) {
809       new_num_elements = -1;
810     } else {
811       new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
812       if (TF_PREDICT_FALSE(new_num_elements < 0)) {
813         TensorShapeProto proto;
814         for (int64_t j = 0; j < n; ++j) {
815           proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
816         }
817         return errors::InvalidArgument(
818             "Shape ", TensorShape::DebugString(proto),
819             " would have more than 2**63 - 1 elements");
820       }
821     }
822     out->UnsafeAddDim(dim, new_num_elements);
823   }
824   return Status::OK();
825 }
826 
827 #define MAKE_SHAPE(T, Shape)                                                 \
828   Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
829     return MakeShapeHelper(dims, n, out);                                    \
830   }                                                                          \
831   Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
832     return MakeShapeHelper(shape.data(), shape.size(), out);                 \
833   }
834 MAKE_SHAPE(int32, TensorShape)
835 MAKE_SHAPE(int64, TensorShape)
836 MAKE_SHAPE(int32, PartialTensorShape)
837 MAKE_SHAPE(int64, PartialTensorShape)
838 #undef MAKE_SHAPE
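// Editor's usage sketch (not part of the original file): building shapes from
// raw dimension arrays with status checking instead of CHECK-failing.
//
//   int32 dims[] = {2, 3, 4};
//   TensorShape shape;
//   Status s = TensorShapeUtils::MakeShape(dims, 3, &shape);   // shape = [2,3,4]
//   int64 pdims[] = {2, -1};
//   PartialTensorShape partial;
//   s = TensorShapeUtils::MakeShape(pdims, 2, &partial);       // partial = [2,?]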
839 
840 string TensorShapeUtils::ShapeListString(
841     const gtl::ArraySlice<TensorShape>& shapes) {
842   string result = "[";
843   bool first = true;
844   for (const TensorShape& shape : shapes) {
845     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
846     first = false;
847   }
848   strings::StrAppend(&result, "]");
849   return result;
850 }
851 
852 PartialTensorShape PartialTensorShape::Concatenate(int64_t size) const {
853   PartialTensorShape out = *this;
854   out.AddDim(size);
855   return out;
856 }
857 
858 Status PartialTensorShape::ConcatenateWithStatus(
859     int64_t size, PartialTensorShape* out) const {
860   *out = *this;  // copy this shape into the caller's output instead of mutating *this
861   return out->AddDimWithStatus(size);
862 }
863 
864 PartialTensorShape PartialTensorShape::Concatenate(
865     const PartialTensorShape& shape) const {
866   if (unknown_rank() || shape.unknown_rank()) {
867     return PartialTensorShape();
868   }
869   PartialTensorShape out = *this;
870   for (auto dim : shape) out.AddDim(dim.size);
871   return out;
872 }
873 
874 Status PartialTensorShape::ConcatenateWithStatus(
875     const PartialTensorShape& shape, PartialTensorShape* out) const {
876   if (unknown_rank() || shape.unknown_rank()) {
877     *out = PartialTensorShape();
878     return Status::OK();
879   }
880   *out = *this;  // copy this shape into the caller's output instead of mutating *this
881   for (auto dim : shape) {
882     Status s = out->AddDimWithStatus(dim.size);
883     if (!s.ok()) return s;
884   }
885 
886   return Status::OK();
887 }
888 
889 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
890                                      PartialTensorShape* result) const {
891   if (unknown_rank()) {
892     *result = shape;
893     return Status::OK();
894   }
895   if (shape.unknown_rank()) {
896     *result = *this;
897     return Status::OK();
898   }
899   const int dims_ = dims();
900   if (dims_ != shape.dims()) {
901     return errors::InvalidArgument(
902         "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
903         shape.dims());
904   }
905 
906   if (result == this) {
907     return errors::Internal(
908         "PartialTensorShape::MergeWith: cannot merge shape with itself");
909   }
910 
911   result->Clear();
912   Status s = Status::OK();
913   for (int i = 0; i < dims_; ++i) {
914     const int64_t dim0 = dim_size(i);
915     const int64_t dim1 = shape.dim_size(i);
916     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
917       return errors::InvalidArgument(
918           "PartialTensorShape: Incompatible shapes during merge: ",
919           DebugString(), " vs. ", shape.DebugString());
920     }
921     s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
922     if (!s.ok()) {
923       return s;
924     }
925   }
926   return Status::OK();
927 }
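// Editor's usage sketch (not part of the original file): MergeWith fills in
// dimensions that are known on either side and rejects conflicting known sizes.
//
//   PartialTensorShape a({2, -1, 5});
//   PartialTensorShape b({-1, 3, 5});
//   PartialTensorShape merged;
//   Status s = a.MergeWith(b, &merged);  // OK, merged = [2,3,5]
//   PartialTensorShape c({3, -1, 5});
//   s = a.MergeWith(c, &merged);         // InvalidArgument: [2,?,5] vs. [3,?,5]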
928 
929 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
930   if (IsFullyDefined()) {
931     const TensorShapeRep* rep = this;
932     *shape = *static_cast<const TensorShape*>(rep);
933     return true;
934   }
935   return false;
936 }
937 
938 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
939   if (unknown_rank() || shape.unknown_rank()) {
940     return unknown_rank() == shape.unknown_rank();
941   }
942   if (dims() != shape.dims()) return false;
943   for (int i = 0; i < dims(); i++) {
944     if (dim_size(i) != shape.dim_size(i)) return false;
945   }
946   return true;
947 }
948 
949 bool PartialTensorShape::IsCompatibleWith(
950     const PartialTensorShape& shape) const {
951   if (unknown_rank() || shape.unknown_rank()) return true;
952   if (dims() != shape.dims()) return false;
953   for (int i = 0; i < dims(); i++) {
954     const int64_t dim0 = dim_size(i);
955     const int64_t dim1 = shape.dim_size(i);
956     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
957   }
958   return true;
959 }
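// Editor's sketch (not part of the original file): compatibility is looser
// than identity; unknown dimensions (and unknown rank) match anything.
//
//   PartialTensorShape a({2, -1});
//   PartialTensorShape b({2, 3});
//   a.IsCompatibleWith(b);                     // true
//   a.IsIdenticalTo(b);                        // false
//   PartialTensorShape().IsCompatibleWith(b);  // true: unknown rank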
960 
961 string PartialTensorShapeUtils::PartialShapeListString(
962     const gtl::ArraySlice<PartialTensorShape>& shapes) {
963   string result = "[";
964   bool first = true;
965   for (const PartialTensorShape& shape : shapes) {
966     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
967     first = false;
968   }
969   strings::StrAppend(&result, "]");
970   return result;
971 }
972 
973 bool PartialTensorShapeUtils::AreCompatible(
974     const gtl::ArraySlice<PartialTensorShape>& shapes0,
975     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
976   if (shapes0.size() == shapes1.size()) {
977     for (size_t i = 0; i < shapes0.size(); ++i) {
978       if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
979         return false;
980       }
981     }
982     return true;
983   } else {
984     return false;
985   }
986 }
987 
988 bool PartialTensorShapeUtils::AreIdentical(
989     const gtl::ArraySlice<PartialTensorShape>& shapes0,
990     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
991   if (shapes0.size() == shapes1.size()) {
992     for (size_t i = 0; i < shapes0.size(); ++i) {
993       if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
994         return false;
995       }
996     }
997     return true;
998   } else {
999     return false;
1000   }
1001 }
1002 
1003 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
1004                                      int64* num_elements) {
1005   int64_t n = 1;
1006   for (auto dim : shape) {
1007     n = MultiplyWithoutOverflow(n, dim);
1008     if (n < 0) {
1009       return errors::InvalidArgument("Can't compute total size of shape [",
1010                                      absl::StrJoin(shape, ","),
1011                                      "]; product would overflow int64");
1012     }
1013   }
1014   *num_elements = n;
1015   return Status::OK();
1016 }
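// Editor's usage sketch (not part of the original file): overflow-checked
// element counting over a raw dimension list.
//
//   int64 n = 0;
//   TF_CHECK_OK(TensorShapeUtils::NumElements({2, 3, 4}, &n));  // n == 24
//   Status s = TensorShapeUtils::NumElements({1LL << 62, 4}, &n);
//   // s is InvalidArgument: the product would overflow int64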
1017 
1018 template class TensorShapeBase<TensorShape>;
1019 template class TensorShapeBase<PartialTensorShape>;
1020 
1021 }  // namespace tensorflow
1022