• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_shape.h"
17 
18 #include "tensorflow/core/framework/bounds_check.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/util/overflow.h"
25 
26 namespace tensorflow {
27 
28 // TensorShape and PartialTensorShape should have no fields beyond
29 // TensorShapeRep.  In particular, their sizes should be the same.
// Code in this file relies on that layout equivalence when it reinterprets a
// TensorShapeRep as one of the subclasses (TensorShapeRep::DebugString casts
// `this` to PartialTensorShape, and PartialTensorShape::AsTensorShape casts
// the shared rep to TensorShape).
30 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
31               "TensorShape must have no fields beyond TensorShapeRep");
32 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
33               "PartialTensorShape must have no fields beyond TensorShapeRep");
34 
35 template <class Shape>
AppendTo(const TensorShapeBase<Shape> & s,gtl::InlinedVector<int64,8> * vals)36 static void AppendTo(const TensorShapeBase<Shape>& s,
37                      gtl::InlinedVector<int64, 8>* vals) {
38   for (auto dim : s) {
39     vals->push_back(dim.size);
40   }
41 }
42 
CheckDimsEqual(int NDIMS) const43 void TensorShape::CheckDimsEqual(int NDIMS) const {
44   CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
45                           << " from a tensor of " << dims() << " dimensions";
46 }
47 
CheckDimsAtLeast(int NDIMS) const48 void TensorShape::CheckDimsAtLeast(int NDIMS) const {
49   CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
50                           << " dimensions from a tensor of " << dims()
51                           << " dimensions";
52 }
53 
54 // TODO(slebedev): Consider merging IsValid implementations.
55 template <class Shape>
IsValid()56 bool TensorShapeBase<Shape>::IsValid() {
57   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
58   // unknown_shape() set, and it seems hard to remove this without backwards
59   // compatibility issues.
60   if (kIsPartial && unknown_rank()) return dims() == 0;
61   int64 num_elements = 1;
62   if (dims() > MaxDimensions()) return false;
63   for (auto d : dim_sizes()) {
64     if (d < (kIsPartial ? -1 : 0)) return false;
65     if (d == -1) {
66       num_elements = -1;
67     } else if (!kIsPartial || num_elements >= 0) {
68       num_elements = MultiplyWithoutOverflow(num_elements, d);
69       if (num_elements < 0) return false;
70     }
71   }
72   return true;
73 }
74 
75 template <class Shape>
IsValid(const TensorShapeProto & proto)76 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
77   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
78   // unknown_shape() set, and it seems hard to remove this without backwards
79   // compatibility issues.
80   if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
81   int64 num_elements = 1;
82   if (proto.dim().size() > MaxDimensions()) return false;
83   for (const auto& d : proto.dim()) {
84     if (d.size() < (kIsPartial ? -1 : 0)) return false;
85     if (d.size() == -1) {
86       num_elements = -1;
87     } else if (!kIsPartial || num_elements >= 0) {
88       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
89       if (num_elements < 0) return false;
90     }
91   }
92   return true;
93 }
94 
95 template <class Shape>
IsValidShape(const TensorShapeProto & proto)96 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
97   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
98   // unknown_shape() set, and it seems hard to remove this without backwards
99   // compatibility issues.
100   if (kIsPartial && proto.unknown_rank()) {
101     if (proto.dim_size() > 0) {
102       return errors::InvalidArgument(
103           "An unknown shape must not have any dimensions set.");
104     }
105     return Status::OK();
106   }
107   int64 num_elements = 1;
108   if (proto.dim().size() > MaxDimensions()) {
109     return errors::InvalidArgument("Shape ", DebugString(proto),
110                                    " has too many dimensions");
111   }
112   for (const auto& d : proto.dim()) {
113     if (d.size() < (kIsPartial ? -1 : 0)) {
114       if (kIsPartial) {
115         return errors::InvalidArgument(
116             "Shape ", DebugString(proto),
117             " has dimensions with values below -1 (where -1 means unknown)");
118       } else {
119         return errors::InvalidArgument("Shape ", DebugString(proto),
120                                        " is not fully defined");
121       }
122     }
123     if (d.size() == -1) {
124       num_elements = -1;
125     } else if (!kIsPartial || num_elements >= 0) {
126       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
127       if (num_elements < 0) {
128         return errors::InvalidArgument(
129             "Shape ", DebugString(proto),
130             " is too large (more than 2**63 - 1 entries)");
131       }
132     }
133   }
134   return Status::OK();
135 }
136 
137 template <class Shape>
TensorShapeBase(const TensorShapeProto & proto)138 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
139   set_tag(REP16);
140   set_data_type(DT_INVALID);
141   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
142   // unknown_shape() set, and it seems hard to remove this without backwards
143   // compatibility issues.
144   if (kIsPartial && proto.unknown_rank()) {
145     set_ndims_byte(kUnknownRank);
146     set_num_elements(-1);
147   } else {
148     set_ndims_byte(0);
149     set_num_elements(1);
150     for (const auto& d : proto.dim()) {
151       AddDim(d.size());
152     }
153   }
154 }
155 
156 template <class Shape>
TensorShapeBase(gtl::ArraySlice<int64> dim_sizes)157 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
158   set_tag(REP16);
159   set_data_type(DT_INVALID);
160   InitDims(dim_sizes);
161 }
162 
163 // Returns true iff partial is true and val is < 0.
164 // REQUIRES: val < kMaxRep16
165 // REQUIRES: partial || val >= 0
Set16(bool partial,uint16 * dst,int dim,int64 val)166 static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
167   if (partial) {
168     if (val < 0) {
169       dst[dim] = std::numeric_limits<uint16>::max();
170       return true;
171     }
172   } else {
173     CHECK_GE(val, 0);
174   }
175   dst[dim] = val;
176   return false;
177 }
178 
// Sets this shape's dimensions to `dim_sizes`. The rep must already be REP16
// (DCHECKed below); callers in this file set the tag first.
//
// Fast path: rank 1-4 with every size <= kMaxSmall is written straight into
// the 16-bit array with an unrolled element-count product. Everything else
// (rank 0, rank > 4, or any size above kMaxSmall) falls through to the
// generic loop at the bottom, which appends one dimension at a time via
// AddDim (validating each size and widening the rep as needed).
179 template <class Shape>
InitDims(gtl::ArraySlice<int64> dim_sizes)180 void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
181   DCHECK_EQ(tag(), REP16);
182 
183   // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
184   // below cannot overflow.
185   static const uint64 kMaxSmall = 0xd744;
186   static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
187                 "bad overflow check");
188   bool large_size = false;
  // NOTE: `s` is int64 but kMaxSmall is uint64, so this comparison is done in
  // unsigned arithmetic. A negative size therefore converts to a huge value
  // and also sets large_size, which routes it to the AddDim loop below. As a
  // result, the `neg` results from Set16 in the fast path appear to always be
  // false in practice.
189   for (auto s : dim_sizes) {
190     if (s > kMaxSmall) {
191       large_size = true;
192       break;
193     }
194   }
195 
196   if (!large_size) {
197     // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
198     uint16* dst = as16()->dims_;
199     switch (dim_sizes.size()) {
200       case 1: {
201         set_ndims_byte(1);
202         const int64 size = dim_sizes[0];
203         const bool neg = Set16(kIsPartial, dst, 0, size);
204         set_num_elements(neg ? -1 : size);
205         return;
206       }
207       case 2: {
208         set_ndims_byte(2);
209         const int64 size0 = dim_sizes[0];
210         const int64 size1 = dim_sizes[1];
211         bool neg = Set16(kIsPartial, dst, 0, size0);
212         neg |= Set16(kIsPartial, dst, 1, size1);
213         set_num_elements(neg ? -1 : (size0 * size1));
214         return;
215       }
216       case 3: {
217         set_ndims_byte(3);
218         const int64 size0 = dim_sizes[0];
219         const int64 size1 = dim_sizes[1];
220         const int64 size2 = dim_sizes[2];
221         bool neg = Set16(kIsPartial, dst, 0, size0);
222         neg |= Set16(kIsPartial, dst, 1, size1);
223         neg |= Set16(kIsPartial, dst, 2, size2);
224         set_num_elements(neg ? -1 : (size0 * size1 * size2));
225         return;
226       }
227       case 4: {
228         set_ndims_byte(4);
229         const int64 size0 = dim_sizes[0];
230         const int64 size1 = dim_sizes[1];
231         const int64 size2 = dim_sizes[2];
232         const int64 size3 = dim_sizes[3];
233         bool neg = Set16(kIsPartial, dst, 0, size0);
234         neg |= Set16(kIsPartial, dst, 1, size1);
235         neg |= Set16(kIsPartial, dst, 2, size2);
236         neg |= Set16(kIsPartial, dst, 3, size3);
237         set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
238         return;
239       }
      // Other ranks (0 or > 4) fall through to the generic path.
240     }
241   }
242 
  // Generic path: start from a scalar and append each dimension. AddDim
  // CHECK-validates each size and upgrades the rep (REP32 / out-of-line) when
  // a size or the rank no longer fits.
243   set_ndims_byte(0);
244   set_num_elements(1);
245   for (int64 s : dim_sizes) {
    // NOTE(review): SubtleMustCopy presumably forces a single read of `s` in
    // case the backing storage can change concurrently; confirm in its
    // definition.
246     AddDim(internal::SubtleMustCopy(s));
247   }
248 }
249 
250 template <class Shape>
TensorShapeBase()251 TensorShapeBase<Shape>::TensorShapeBase() {
252   set_tag(REP16);
253   set_data_type(DT_INVALID);
254   if (kIsPartial) {
255     set_ndims_byte(kUnknownRank);
256     set_num_elements(-1);
257   } else {
258     set_ndims_byte(0);
259     set_num_elements(1);
260   }
261 }
262 
DestructorOutOfLine()263 void TensorShapeRep::DestructorOutOfLine() {
264   DCHECK(tag() == REP_OUT_OF_LINE);
265   delete as64()->dims_;
266 }
267 
// Copies `b`'s dimension representation, rank byte and data type into this
// rep, managing ownership of the out-of-line dimension vector:
//   * `b` inline:       free our own heap vector (if any), then copy the raw
//                       buffer bytes.
//   * both out-of-line: reuse our existing vector and assign `b`'s into it.
//   * only `b` out-of-line: allocate a fresh copy of `b`'s vector.
// num_elements is not copied here. NOTE(review): presumably the inline copy
// constructor / operator= in tensor_shape.h copy it before calling this;
// confirm.
SlowCopyFrom(const TensorShapeRep & b)268 void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
269   if (b.tag() != REP_OUT_OF_LINE) {
    // Free our heap vector before the memcpy overwrites the pointer to it.
270     if (tag() == REP_OUT_OF_LINE) {
271       delete as64()->dims_;
272     }
273     memcpy(buf(), b.buf(), sizeof(u_.buf));
274     // memcpy above implicitly also does:
275     //   set_tag(b.tag());
276     //   set_ndims_byte(b.ndims_byte());
277     //   set_data_type(b.data_type());
278   } else {
279     DCHECK_EQ(b.tag(), REP_OUT_OF_LINE);
280     set_ndims_byte(b.ndims_byte());
281     set_data_type(b.data_type());
282     if (tag() == REP_OUT_OF_LINE) {
283       // vector already allocated
284       *(as64()->dims_) = *(b.as64()->dims_);
285     } else {
      // We were inline (REP16/REP32), so there is no heap vector to free.
286       set_tag(REP_OUT_OF_LINE);
287       as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
288     }
289   }
290 }
291 
292 template <class Shape>
dim_size(int d) const293 int64 TensorShapeBase<Shape>::dim_size(int d) const {
294   if (unknown_rank()) return -1;
295   DCHECK_GE(d, 0);
296   DCHECK_LT(d, dims());
297   if (tag() == REP16) {
298     uint16 dim = as16()->dims_[d];
299     if (kIsPartial && dim == kUnknownRep16) return -1;
300     return dim;
301   } else if (tag() == REP32) {
302     uint32 dim = as32()->dims_[d];
303     if (kIsPartial && dim == kUnknownRep32) return -1;
304     return dim;
305   } else {
306     return (*as64()->dims_)[d];
307   }
308 }
309 
Clear()310 void TensorShapeRep::Clear() {
311   ClearAllButDataType();
312   set_data_type(DT_INVALID);
313 }
314 
ClearAllButDataType()315 void TensorShapeRep::ClearAllButDataType() {
316   if (tag() == REP_OUT_OF_LINE) {
317     delete as64()->dims_;
318   }
319   set_tag(REP16);
320   set_ndims_byte(0);
321   // Leaves data_type alone
322   set_num_elements(1);
323 }
324 
325 template <class Shape>
RecomputeNumElements()326 void TensorShapeBase<Shape>::RecomputeNumElements() {
327   if (unknown_rank()) {
328     set_num_elements(-1);
329     return;
330   }
331   int64 n = 1;
332   for (auto dim : *this) {
333     if (kIsPartial && dim.size < 0) {
334       n = -1;
335       break;
336     }
337     n = MultiplyWithoutOverflow(n, dim.size);
338     CHECK_LE(0, n);
339   }
340   set_num_elements(n);
341 }
342 
343 template <class Shape>
AddDim(int64 size)344 void TensorShapeBase<Shape>::AddDim(int64 size) {
345   if (!kIsPartial) CHECK_GE(size, 0);
346   if (unknown_rank()) return;
347   CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
348   int64 new_num_elements;
349   if (kIsPartial && (num_elements() < 0 || size < 0)) {
350     new_num_elements = -1;
351   } else {
352     new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
353     CHECK_LE(0, new_num_elements);
354   }
355   UnsafeAddDim(size, new_num_elements);
356 }
357 
// Appends a dimension of size `size` and stores `new_num_elements` as the new
// element count, without validating either. Callers (AddDim, MakeShapeHelper)
// are responsible for checking the size, the rank limit and overflow first.
//
// Representation capacities, as encoded below: REP16 holds up to 6 dims, each
// below kMaxRep16; REP32 holds up to 3 dims, each below kMaxRep32. Anything
// else moves to the heap-allocated out-of-line vector.
358 template <class Shape>
UnsafeAddDim(int64 size,int64 new_num_elements)359 void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
360   const int nd = ndims_byte();
  // Fast cases: the new dim fits in the current representation.
361   if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
362     as16()->dims_[nd] =
363         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
364   } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
365     as32()->dims_[nd] =
366         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
367   } else if (tag() == REP_OUT_OF_LINE) {
368     as64()->dims_->push_back(size);
369   } else {
    // Here the current rep is REP16 or REP32 (out-of-line was handled above),
    // so there is no heap vector to free before the buffer is overwritten.
370     // Need to change representation
371     gtl::InlinedVector<int64, 8> vals;
372     AppendTo(*this, &vals);
373     vals.push_back(size);
374     // We know we can't be REP16.  See if we have a small enough
375     // number of dimensions and each dimension's size is small enough
376     // to allow REP32.
377     bool can_be_rep32 = (vals.size() <= 3);
378     if (can_be_rep32) {
379       for (size_t i = 0; i < vals.size(); i++) {
380         if (vals[i] >= kMaxRep32) {
381           can_be_rep32 = false;
382           break;
383         }
384       }
385     }
386     if (can_be_rep32) {
387       set_tag(REP32);
388       for (size_t d = 0; d < vals.size(); d++) {
389         as32()->dims_[d] = kIsPartial && vals[d] < 0
390                                ? kUnknownRep32
391                                : static_cast<uint32>(vals[d]);
392       }
393     } else {
394       set_tag(REP_OUT_OF_LINE);
395       as64()->dims_ =
396           new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
397     }
398   }
  // Rank and element count are updated last, whichever branch stored the dim.
399   set_ndims_byte(nd + 1);
400   set_num_elements(new_num_elements);
401 }
402 
403 template <class Shape>
AppendShape(const TensorShapeBase & shape)404 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
405   for (auto d : shape) AddDim(d.size);
406 }
407 
408 template <class Shape>
InsertDim(int d,int64 size)409 void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
410   CHECK_GE(d, 0);
411   CHECK_LE(d, dims());
412   if (!kIsPartial) CHECK_GE(size, 0);
413   CHECK_LT(dims(), MaxDimensions());
414   gtl::InlinedVector<int64, 8> vals;
415   AppendTo(*this, &vals);
416   vals.insert(vals.begin() + d, size);
417   ClearAllButDataType();
418   for (auto dval : vals) {
419     AddDim(dval);
420   }
421 }
422 
423 template <class Shape>
dim_sizes() const424 gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
425   gtl::InlinedVector<int64, 4> result;
426   for (auto dim : *this) {
427     result.push_back(dim.size);
428   }
429   return result;
430 }
431 
432 template <class Shape>
set_dim(int d,int64 size)433 void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
434   CHECK_GE(d, 0);
435   CHECK_LT(d, dims());
436   CHECK_GE(size, 0);
437   if (tag() == REP16 && size < kMaxRep16) {
438     as16()->dims_[d] =
439         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
440   } else if (tag() == REP32 && size < kMaxRep32) {
441     as32()->dims_[d] =
442         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
443   } else if (tag() == REP_OUT_OF_LINE) {
444     (*as64()->dims_)[d] = size;
445   } else {
446     // Must upgrade
447     gtl::InlinedVector<int64, 8> vals;
448     AppendTo(*this, &vals);
449     vals[d] = size;
450     ClearAllButDataType();
451     for (auto dval : vals) {
452       AddDim(dval);
453     }
454   }
455   RecomputeNumElements();
456 }
457 
458 template <class Shape>
RemoveDimRange(int begin,int end)459 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
460   if (unknown_rank()) return;
461   begin = begin < 0 ? dims() + begin + 1 : begin;
462   end = end < 0 ? dims() + end + 1 : end;
463   CHECK_GE(begin, 0);
464   CHECK_LE(begin, dims());
465   CHECK_GE(end, 0);
466   CHECK_LE(end, dims());
467   if (begin >= end) return;
468   gtl::InlinedVector<int64, 8> vals;
469   AppendTo(*this, &vals);
470   vals.erase(vals.begin() + begin, vals.begin() + end);
471   ClearAllButDataType();
472   for (auto dval : vals) {
473     AddDim(dval);
474   }
475   RecomputeNumElements();
476 }
477 
IsSameSize(const TensorShape & b) const478 bool TensorShape::IsSameSize(const TensorShape& b) const {
479   if (b.dims() != dims()) return false;
480   for (int d = 0; d < dims(); d++) {
481     if (dim_size(d) != b.dim_size(d)) return false;
482   }
483   return true;
484 }
485 
486 template <class Shape>
AsProto(TensorShapeProto * proto) const487 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
488   proto->Clear();
489   if (unknown_rank()) {
490     proto->set_unknown_rank(true);
491   } else {
492     for (int i = 0; i < dims(); i++) {
493       proto->add_dim()->set_size(dim_size(i));
494     }
495   }
496 }
497 
// Debugging aid that would print the raw representation to stderr. The body
// is disabled with `#if 0`, so the function currently does nothing.
// NOTE(review): the disabled code refers to `tag_` and `ndims_`, which the
// rest of this file never uses (it goes through tag() / ndims_byte()). It
// also prints `as16()->dims_` in the out-of-line case. Verify that it still
// compiles and is correct before re-enabling it.
DumpRep() const498 void TensorShapeRep::DumpRep() const {
499 #if 0
500   fprintf(stderr, "Rep: %d %d dims\n", tag(), dims());
501   if (tag() == REP16) {
502     fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte());
503     for (int i = 0; i < ndims_byte(); i++) {
504       fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]);
505     }
506   } else if (tag_ == REP32) {
507     fprintf(stderr, "REP32 NDIMS: %d\n", ndims_);
508     for (int i = 0; i < ndims_byte(); i++) {
509       fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]);
510     }
511   } else if (tag_ == REP_OUT_OF_LINE) {
512     fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_, as16()->dims_);
513     for (int i = 0; i < ndims_byte(); i++) {
514       fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]);
515     }
516   }
517 #endif
518 }
519 
520 template <class Shape>
begin() const521 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
522   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
523 }
524 
525 template <class Shape>
end() const526 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
527   CHECK(!unknown_rank());
528   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims());
529 }
530 
DebugString() const531 string TensorShapeRep::DebugString() const {
532   const auto& shape = *static_cast<const PartialTensorShape*>(this);
533   if (shape.unknown_rank()) return "<unknown>";
534   string s = "[";
535   for (int i = 0; i < shape.dims(); i++) {
536     if (i > 0) strings::StrAppend(&s, ",");
537     int64 dim = shape.dim_size(i);
538     if (dim < 0) {
539       strings::StrAppend(&s, "?");
540     } else {
541       strings::StrAppend(&s, dim);
542     }
543   }
544   strings::StrAppend(&s, "]");
545   return s;
546 }
547 
DebugString(const TensorShapeProto & proto)548 string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
549   string s;
550   if (proto.unknown_rank()) {
551     strings::StrAppend(&s, "<unknown>");
552     if (proto.dim_size() == 0) return s;
553   }
554   strings::StrAppend(&s, "[");
555   bool first = true;
556   for (const auto& d : proto.dim()) {
557     if (!first) strings::StrAppend(&s, ",");
558     if (d.size() == -1) {
559       strings::StrAppend(&s, "?");
560     } else {
561       strings::StrAppend(&s, d.size());
562     }
563     first = false;
564   }
565   strings::StrAppend(&s, "]");
566   return s;
567 }
568 
StartsWith(const TensorShape & shape,const TensorShape & prefix)569 bool TensorShapeUtils::StartsWith(const TensorShape& shape,
570                                   const TensorShape& prefix) {
571   if (shape.dims() < prefix.dims()) return false;
572   for (int i = 0; i < prefix.dims(); ++i) {
573     if (shape.dim_size(i) != prefix.dim_size(i)) return false;
574   }
575   return true;
576 }
577 
EndsWith(const TensorShape & shape,const TensorShape & suffix)578 bool TensorShapeUtils::EndsWith(const TensorShape& shape,
579                                 const TensorShape& suffix) {
580   const int suffix_size = suffix.dims();
581   if (shape.dims() < suffix_size) return false;
582   for (int i = 0; i < suffix_size; ++i) {
583     if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
584       return false;
585     }
586   }
587   return true;
588 }
589 
590 template <typename T, class Shape>
MakeShapeHelper(const T * dims,int64 n,Shape * out)591 Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
592   out->Clear();
593   if (n > TensorShape::MaxDimensions()) {
594     return errors::InvalidArgument("Too many dimensions");
595   }
596   if (n < 0) {
597     return errors::InvalidArgument("Negative number of dimensions ", n);
598   }
599   for (int64 i = 0; i < n; ++i) {
600     T dim = internal::SubtleMustCopy(dims[i]);
601     int64 new_num_elements;
602     if (dim < 0) {
603       if (!out->kIsPartial) {
604         return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
605       }
606       if (dim < -1) {
607         return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
608       }
609       dim = -1;
610       new_num_elements = -1;
611     } else if (out->num_elements() < 0) {
612       new_num_elements = -1;
613     } else {
614       new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
615       if (TF_PREDICT_FALSE(new_num_elements < 0)) {
616         TensorShapeProto proto;
617         for (int64 j = 0; j < n; ++j) {
618           proto.add_dim()->set_size(dim);
619         }
620         return errors::InvalidArgument(
621             "Shape ", TensorShape::DebugString(proto),
622             " would have more than 2**63 - 1 elements");
623       }
624     }
625     out->UnsafeAddDim(dim, new_num_elements);
626   }
627   return Status::OK();
628 }
629 
// Defines the two TensorShapeUtils::MakeShape overloads (raw pointer + count,
// and ArraySlice) for one input element type T and one output shape type.
// Both forward to MakeShapeHelper.
630 #define MAKE_SHAPE(T, Shape)                                                 \
631   Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
632     return MakeShapeHelper(dims, n, out);                                    \
633   }                                                                          \
634   Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
635     return MakeShapeHelper(shape.data(), shape.size(), out);                 \
636   }
// Instantiate every supported combination of {int32, int64} input sizes and
// {TensorShape, PartialTensorShape} outputs.
MAKE_SHAPE(int32,TensorShape)637 MAKE_SHAPE(int32, TensorShape)
638 MAKE_SHAPE(int64, TensorShape)
639 MAKE_SHAPE(int32, PartialTensorShape)
640 MAKE_SHAPE(int64, PartialTensorShape)
641 #undef MAKE_SHAPE
642 
643 string TensorShapeUtils::ShapeListString(
644     const gtl::ArraySlice<TensorShape>& shapes) {
645   string result = "[";
646   bool first = true;
647   for (const TensorShape& shape : shapes) {
648     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
649     first = false;
650   }
651   strings::StrAppend(&result, "]");
652   return result;
653 }
654 
Concatenate(int64 size) const655 PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
656   PartialTensorShape out = *this;
657   out.AddDim(size);
658   return out;
659 }
660 
// Returns this shape followed by all dimensions of `shape`. If either side
// has unknown rank, the result is the unknown-rank shape.
PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return PartialTensorShape();
  PartialTensorShape combined(*this);
  combined.AppendShape(shape);
  return combined;
}
670 
MergeWith(const PartialTensorShape & shape,PartialTensorShape * result) const671 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
672                                      PartialTensorShape* result) const {
673   if (unknown_rank()) {
674     *result = shape;
675     return Status::OK();
676   }
677   if (shape.unknown_rank()) {
678     *result = *this;
679     return Status::OK();
680   }
681   const int dims_ = dims();
682   if (dims_ != shape.dims()) {
683     return errors::InvalidArgument(
684         "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
685         shape.dims());
686   }
687   CHECK(result != this);
688   result->Clear();
689   for (int i = 0; i < dims_; ++i) {
690     const int64 dim0 = dim_size(i);
691     const int64 dim1 = shape.dim_size(i);
692     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
693       return errors::InvalidArgument(
694           "PartialTensorShape: Incompatible shapes during merge: ",
695           DebugString(), " vs. ", shape.DebugString());
696     }
697     result->AddDim(dim0 >= 0 ? dim0 : dim1);
698   }
699   return Status::OK();
700 }
701 
AsTensorShape(TensorShape * shape) const702 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
703   if (IsFullyDefined()) {
704     const TensorShapeRep* rep = this;
705     *shape = *static_cast<const TensorShape*>(rep);
706     return true;
707   }
708   return false;
709 }
710 
IsIdenticalTo(const PartialTensorShape & shape) const711 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
712   if (unknown_rank() || shape.unknown_rank()) {
713     return unknown_rank() == shape.unknown_rank();
714   }
715   if (dims() != shape.dims()) return false;
716   for (int i = 0; i < dims(); i++) {
717     if (dim_size(i) != shape.dim_size(i)) return false;
718   }
719   return true;
720 }
721 
IsCompatibleWith(const PartialTensorShape & shape) const722 bool PartialTensorShape::IsCompatibleWith(
723     const PartialTensorShape& shape) const {
724   if (unknown_rank() || shape.unknown_rank()) return true;
725   if (dims() != shape.dims()) return false;
726   for (int i = 0; i < dims(); i++) {
727     const int64 dim0 = dim_size(i);
728     const int64 dim1 = shape.dim_size(i);
729     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
730   }
731   return true;
732 }
733 
PartialShapeListString(const gtl::ArraySlice<PartialTensorShape> & shapes)734 string PartialTensorShapeUtils::PartialShapeListString(
735     const gtl::ArraySlice<PartialTensorShape>& shapes) {
736   string result = "[";
737   bool first = true;
738   for (const PartialTensorShape& shape : shapes) {
739     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
740     first = false;
741   }
742   strings::StrAppend(&result, "]");
743   return result;
744 }
745 
AreCompatible(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)746 bool PartialTensorShapeUtils::AreCompatible(
747     const gtl::ArraySlice<PartialTensorShape>& shapes0,
748     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
749   if (shapes0.size() == shapes1.size()) {
750     for (size_t i = 0; i < shapes0.size(); ++i) {
751       if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
752         return false;
753       }
754     }
755     return true;
756   } else {
757     return false;
758   }
759 }
760 
AreIdentical(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)761 bool PartialTensorShapeUtils::AreIdentical(
762     const gtl::ArraySlice<PartialTensorShape>& shapes0,
763     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
764   if (shapes0.size() == shapes1.size()) {
765     for (size_t i = 0; i < shapes0.size(); ++i) {
766       if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
767         return false;
768       }
769     }
770     return true;
771   } else {
772     return false;
773   }
774 }
775 
NumElements(gtl::ArraySlice<int64> shape,int64 * num_elements)776 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
777                                      int64* num_elements) {
778   int64 n = 1;
779   for (auto dim : shape) {
780     n = MultiplyWithoutOverflow(n, dim);
781     if (n < 0) {
782       return errors::InvalidArgument("Can't compute total size of shape [",
783                                      absl::StrJoin(shape, ","),
784                                      "]; product would overflow int64");
785     }
786   }
787   *num_elements = n;
788   return Status::OK();
789 }
790 
// The TensorShapeBase member templates are defined in this .cc file rather
// than the header. These explicit instantiations emit them for the two shape
// flavors so other translation units can link against them.
791 template class TensorShapeBase<TensorShape>;
792 template class TensorShapeBase<PartialTensorShape>;
793 
794 }  // namespace tensorflow
795