/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_shape.h"

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
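// TensorShapeTestHelper is presumably declared a friend of TensorShape in
// tensor_shape.h; that is what lets these static methods read and set the
// private data_type tag that TensorShape carries alongside its dimensions.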
class TensorShapeTestHelper {
 public:
  static void set_data_type(TensorShape* s, DataType t) { s->set_data_type(t); }
  static uint8 data_type(const TensorShape* s) { return s->data_type(); }
};

namespace {

TEST(TensorShapeTest, Default) {
  // The default TensorShape constructor constructs a shape of 0-dim
  // and 1-element.
  TensorShape s;
  EXPECT_EQ(s.dims(), 0);
  EXPECT_EQ(s.num_elements(), 1);
}

TEST(TensorShapeTest, set_dim) {
  TensorShape s({10, 5});

  s.set_dim(0, 20);
  ASSERT_EQ(2, s.dims());
  EXPECT_EQ(20, s.dim_size(0));
  EXPECT_EQ(100, s.num_elements());

  s.set_dim(1, 2);
  ASSERT_EQ(2, s.dims());
  EXPECT_EQ(2, s.dim_size(1));
  EXPECT_EQ(40, s.num_elements());
}

TEST(TensorShapeTest, RemoveDim) {
  TensorShape s({10, 5});
  s.RemoveDim(0);
  EXPECT_EQ(5, s.num_elements());
  ASSERT_EQ(1, s.dims());
}

TEST(TensorShapeTest, RemoveAndAddDim) {
  TensorShape s({10, 5, 20});
  s.RemoveDim(1);
  s.AddDim(100);

  EXPECT_EQ(20000, s.num_elements());
  ASSERT_EQ(3, s.dims());
}

TEST(TensorShapeTest, RemoveLastDims) {
  TensorShape s({2, 3, 5, 7});
  s.RemoveLastDims(1);

  ASSERT_EQ(3, s.dims());
  EXPECT_EQ(30, s.num_elements());

  s.RemoveLastDims(2);
  ASSERT_EQ(1, s.dims());
  EXPECT_EQ(2, s.dim_size(0));
}

TEST(TensorShapeTest, RemoveDimRange) {
  TensorShape s0({2, 3, 5, 7});
  // Empty interval => noop.
  for (int i = -4; i <= 4; ++i) {
    s0.RemoveDimRange(i, i);
    ASSERT_EQ(4, s0.dims());
    ASSERT_EQ(210, s0.num_elements());
  }

  // Positive begin and end.
  s0.RemoveDimRange(3, 1);  // Empty interval.
  ASSERT_EQ(4, s0.dims());
  ASSERT_EQ(210, s0.num_elements());
  s0.RemoveDimRange(0, 3);
  ASSERT_EQ(1, s0.dims());
  EXPECT_EQ(7, s0.dim_size(0));
  TensorShape s1({2, 3, 5, 7});
  s1.RemoveDimRange(2, 3);
  ASSERT_EQ(3, s1.dims());
  ASSERT_EQ(42, s1.num_elements());

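  // Note: from the expectations below, a negative index d appears to be
  // interpreted as dims() + d + 1 (so on this 4-D shape, -2 maps to 3 and
  // -3 maps to 2); this is inferred from the test itself, not from API docs.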
  // Negative begin or end.
  TensorShape s2({2, 3, 5, 7});
  s2.RemoveDimRange(-2, -3);  // Empty interval.
  ASSERT_EQ(4, s2.dims());
  ASSERT_EQ(210, s2.num_elements());
  s2.RemoveDimRange(0, -2);
  ASSERT_EQ(1, s2.dims());
  ASSERT_EQ(7, s2.dim_size(0));
  TensorShape s3({2, 3, 5, 7});
  s3.RemoveDimRange(-3, -2);
  ASSERT_EQ(3, s3.dims());
  ASSERT_EQ(42, s3.num_elements());
}

TEST(TensorShapeTest, InvalidShapeProto) {
  TensorShapeProto proto;
  EXPECT_TRUE(TensorShape::IsValid(proto));

  proto.add_dim()->set_size(357);
  proto.add_dim()->set_size(982);
  EXPECT_TRUE(TensorShape::IsValid(proto));

  proto.Clear();
  proto.add_dim()->set_size(-357);
  proto.add_dim()->set_size(-982);
  EXPECT_FALSE(TensorShape::IsValid(proto));

  proto.Clear();
  proto.add_dim()->set_size(1LL << 35);
  proto.add_dim()->set_size((1LL << 35) + 1);
  EXPECT_FALSE(TensorShape::IsValid(proto));
}

TEST(TensorShapeTest, TooManyDimsProto) {
  TensorShapeProto proto;
  // Deliberate redundancy to ensure that both paths work.
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  for (int i = 0; i < TensorShape::MaxDimensions(); i++) {
    proto.add_dim()->set_size(1);
  }
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  proto.add_dim()->set_size(1);
  EXPECT_FALSE(TensorShape::IsValid(proto));
  EXPECT_FALSE(TensorShape::IsValidShape(proto).ok());
}

TEST(TensorShapeTest, SetDimForEmptyTensor) {
  TensorShape s({10, 5, 20});
  EXPECT_EQ(1000, s.num_elements());
  s.set_dim(1, 0);
  EXPECT_EQ(0, s.num_elements());
  s.set_dim(1, 7);
  EXPECT_EQ(1400, s.num_elements());
}

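// 2147483648 is 2^31, one past the int32 range, so the dimensions below only
// fit in 64-bit sizes (and presumably force TensorShape out of any narrower
// internal representation).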
TEST(TensorShapeTest, AppendShape64BitIndices) {
  TensorShape s({10, 2147483648});

  EXPECT_EQ(10, s.dim_size(0));
  EXPECT_EQ(2147483648, s.dim_size(1));

  TensorShape s2;
  s2.AppendShape(s);
  EXPECT_EQ(10, s2.dim_size(0));
  EXPECT_EQ(2147483648, s2.dim_size(1));
}

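// The data_type tag is presumably stored inside TensorShape's internal
// representation; this test checks that the tag survives as dimensions are
// added (which can change that representation) and across copy, RemoveDim,
// and Clear.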
TEST(TensorShapeTest, DataType) {
  TensorShape s({});
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INVALID);
  TensorShapeTestHelper::set_data_type(&s, DT_INT32);
  s.AddDim(1);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32);
  s.AddDim(100000);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32);
  TensorShapeTestHelper::set_data_type(&s, DT_UINT16_REF);
  s.AddDim(2);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);
  s.AddDim(4);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);
  s.AddDim(3);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);

  TensorShape s2 = s;
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF);
  s2.RemoveDim(2);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF);
  TensorShapeTestHelper::set_data_type(&s2, DT_FLOAT);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_FLOAT);
  s2.Clear();
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_INVALID);
}

TEST(TensorShapeTest, ostream) {
  TensorShape s({10, 5, 4});
  std::stringstream ss;
  ss << s;
  EXPECT_EQ(ss.str(), "[10,5,4]");
}

TEST(TensorShapeTest, AddDimWithStatus) {
  TensorShape s({10, 5, 20});
  Status status = s.AddDimWithStatus(400);
  EXPECT_TRUE(status.ok());
  EXPECT_EQ(400000, s.num_elements());
  ASSERT_EQ(4, s.dims());

  status = s.AddDimWithStatus(-1);
  EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
}

TEST(TensorShapeTest, Factory) {
  TensorShape s;
  Status status = TensorShape::BuildTensorShapeBase({10, 5, 20}, &s);
  EXPECT_TRUE(status.ok());
  EXPECT_EQ(1000, s.num_elements());
  ASSERT_EQ(3, s.dims());

  status = TensorShape::BuildTensorShapeBase({-10, 5, 20}, &s);
  EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
}

// -----------------------------------------------------------------------
// An old implementation of TensorShape using a different representation,
// preserved here in the unittest to allow us to have a randomized unittest
// that makes sure the behavior of TensorShape and TensorShapeOld are
// the same.
class TensorShapeIterOld;  // Declared below

/// Manages the dimensions of a Tensor and their sizes.
class TensorShapeOld {
 public:
  /// \brief Construct a `TensorShape` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0`
  explicit TensorShapeOld(gtl::ArraySlice<int64_t> dim_sizes);
  TensorShapeOld(std::initializer_list<int64_t> dim_sizes)
      : TensorShapeOld(gtl::ArraySlice<int64_t>(dim_sizes)) {}

  /// REQUIRES: `IsValid(proto)`
  explicit TensorShapeOld(const TensorShapeProto& proto);

  /// Create a tensor shape with no dimensions and one element, which you can
  /// then call `AddDim()` on.
  TensorShapeOld();

  /// Returns `true` iff `proto` is a valid tensor shape.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// Clear a tensor shape
  void Clear();

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64_t size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeOld& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64_t size);

  /// \brief Modifies the size of the dimension `d` to be `size`
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64_t size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d);

  /// Return the number of dimensions in the tensor.
  int dims() const { return dim_sizes_.size(); }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64_t dim_size(int d) const {
    DCHECK_GE(d, 0);
    DCHECK_LT(d, dims());
    return dim_sizes_[d];
  }

  /// Returns sizes of all dimensions.
  gtl::ArraySlice<int64_t> dim_sizes() const { return dim_sizes_; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.
  int64_t num_elements() const { return num_elements_; }

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShapeOld& b) const;
  bool operator==(const TensorShapeOld& b) const { return IsSameSize(b); }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// Fill `*dsizes` from `*this`.
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;

  /// For iterating through the dimensions.
  TensorShapeIterOld begin() const;
  TensorShapeIterOld end() const;

  /// For error messages.
  string DebugString() const;

  /// Same as `TensorShape(proto).DebugString()` but doesn't crash for
  /// invalid protos.
  static string DebugString(const TensorShapeProto& proto);

 private:
  // Recalculates the dimensions of this tensor after they are modified.
  void recompute_dims();

  // TODO(josh11b): Maybe use something from the Eigen Tensor library
  // for the sizes.
  gtl::InlinedVector<int64_t, 4> dim_sizes_;

  // total number of elements (avoids recomputing it each time).
  int64_t num_elements_;
};

struct TensorShapeDimOld {
  explicit TensorShapeDimOld(int64_t s) : size(s) {}
  int64_t size;
};

class TensorShapeIterOld {
 public:
  TensorShapeIterOld(const TensorShapeOld* shape, int d)
      : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIterOld& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIterOld& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDimOld operator*() {
    return TensorShapeDimOld(shape_->dim_size(d_));
  }

 private:
  const TensorShapeOld* shape_;
  int d_;
};

// An upper limit of the total number of elements in a tensor.
static const int64_t kMaxElements = (1LL << 40);

bool TensorShapeOld::IsValid(const TensorShapeProto& proto) {
  int64_t num_elements = 1;
  for (const auto& d : proto.dim()) {
    if (d.size() < 0) return false;
    num_elements *= d.size();
    if (num_elements > kMaxElements) return false;
  }
  return true;
}

Status TensorShapeOld::IsValidShape(const TensorShapeProto& proto) {
  int64_t num_elements = 1;
  for (const auto& d : proto.dim()) {
    if (d.size() < 0) {
      return errors::InvalidArgument("Shape ", DebugString(proto),
                                     " has negative dimensions; ",
                                     "perhaps an un-fed placeholder?");
    }
    num_elements *= d.size();
    if (num_elements > kMaxElements) {
      return errors::InvalidArgument("Shape ", DebugString(proto),
                                     " is too large (more than ", kMaxElements,
                                     " entries)");
    }
  }
  return OkStatus();
}

TensorShapeOld::TensorShapeOld(const TensorShapeProto& proto) {
  dim_sizes_.reserve(proto.dim_size());
  num_elements_ = 1;
  for (const auto& d : proto.dim()) {
    AddDim(d.size());
  }
}

TensorShapeOld::TensorShapeOld(gtl::ArraySlice<int64_t> dim_sizes) {
  dim_sizes_.reserve(dim_sizes.size());
  num_elements_ = 1;
  for (auto s : dim_sizes) {
    AddDim(s);
  }
}

TensorShapeOld::TensorShapeOld() : num_elements_(1) {}

void TensorShapeOld::Clear() {
  dim_sizes_.clear();
  num_elements_ = 1;
}

void TensorShapeOld::AddDim(int64_t size) {
  CHECK_GE(size, 0);
  dim_sizes_.push_back(size);
  num_elements_ *= size;
  CHECK_LE(0, num_elements_);
  CHECK_LE(num_elements_, kMaxElements);
}

void TensorShapeOld::AppendShape(const TensorShapeOld& shape) {
  for (auto d : shape) AddDim(d.size);
}

void TensorShapeOld::InsertDim(int d, int64_t size) {
  CHECK_GE(d, 0);
  CHECK_LE(d, dims());
  CHECK_GE(size, 0);
  dim_sizes_.insert(dim_sizes_.begin() + d, size);
  num_elements_ *= size;
  CHECK_LE(0, num_elements_);
  CHECK_LE(num_elements_, kMaxElements);
}

void TensorShapeOld::set_dim(int d, int64_t size) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());
  CHECK_GE(size, 0);

  // Update the number of elements. num_elements_ is int64.
  dim_sizes_[d] = size;
  recompute_dims();
}

void TensorShapeOld::RemoveDim(int d) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());

  // Update the number of elements and remove the dimension from the
  // sizes.
  dim_sizes_.erase(dim_sizes_.begin() + d);
  recompute_dims();
}

void TensorShapeOld::recompute_dims() {
  num_elements_ = 1;
  for (auto s : dim_sizes_) {
    num_elements_ *= s;
    CHECK_LE(0, num_elements_);
    CHECK_LE(num_elements_, kMaxElements);
  }
}

bool TensorShapeOld::IsSameSize(const TensorShapeOld& b) const {
  if (b.dims() != dims()) return false;
  for (int d = 0; d < dims(); d++) {
    if (dim_size(d) != b.dim_size(d)) return false;
  }
  return true;
}

void TensorShapeOld::AsProto(TensorShapeProto* proto) const {
  proto->Clear();
  for (size_t d = 0; d < dim_sizes_.size(); ++d) {
    auto* dim = proto->add_dim();
    dim->set_size(dim_sizes_[d]);
  }
}

TensorShapeIterOld TensorShapeOld::begin() const {
  return TensorShapeIterOld(this, 0);
}

TensorShapeIterOld TensorShapeOld::end() const {
  return TensorShapeIterOld(this, dims());
}

string TensorShapeOld::DebugString() const {
  return strings::StrCat(
      "[", absl::StrJoin(gtl::ArraySlice<int64_t>(dim_sizes_), ","), "]");
}

string TensorShapeOld::DebugString(const TensorShapeProto& proto) {
  string s = "[";
  bool first = true;
  for (const auto& d : proto.dim()) {
    strings::StrAppend(&s, first ? "" : ",", d.size());
    first = false;
  }
  strings::StrAppend(&s, "]");
  return s;
}
// End of old implementation
// ------------------------------------------------------------------------

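// SkewedSize returns a random dimension size for the randomized test below:
// sizes up to ~100000 while the running element count is still small, and
// only 0 or 1 once it is large, rejecting any candidate whose product with
// current_elements would reach 2^34 or overflow to a negative value.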
static int64_t SkewedSize(random::SimplePhilox* gen, int64_t current_elements) {
  int64_t result = 0;
  do {
    if (current_elements < 100) {
      result = gen->Uniform(100000);
    } else {
      result = gen->Uniform(2);
    }
  } while ((result * current_elements >= 1LL << 34) ||
           (result * current_elements < 0));
  return result;
}

TEST(TensorShapeTest, Randomized) {
  // We do a randomized test to verify that the behavior of the
  // TensorShape implementation (which changes representations depending
  // on the values) is identical to our older, more straightforward (but
  // more memory hungry) implementation (TensorShapeOld).
  random::PhiloxRandom philox(7, 7);
  random::SimplePhilox gen(&philox);
  TensorShape s;
  TensorShapeOld sold;
  TensorShapeProto sp;
  TensorShapeProto spold;
  LOG(INFO) << "Sizes: " << sizeof(TensorShape) << " vs "
            << sizeof(TensorShapeOld);
  for (int i = 0; i < 100000; i++) {
    s.AsProto(&sp);
    sold.AsProto(&spold);
    EXPECT_EQ(sp.DebugString(), spold.DebugString());
    if ((i % 1000) == 0) {
      fprintf(stderr, "ITERATION %d: %s\n", i, sp.DebugString().c_str());
    }
    EXPECT_EQ(s.num_elements(), sold.num_elements());

    // Test moves.
    TensorShape copy = s;
    TensorShape moved(std::move(copy));
    EXPECT_EQ(s, moved);
    copy = s;
    moved = std::move(copy);
    EXPECT_EQ(s, moved);

    int64_t ne = sold.num_elements();
    int r = gen.Uniform(100);
    if (r < 10) {
      int64_t sz = SkewedSize(&gen, sold.num_elements());
      s.AddDim(sz);
      sold.AddDim(sz);
    } else if (r < 15) {
      s.Clear();
      sold.Clear();
    } else if (r < 35 && s.dims() > 0 && ne > 0 && ne < 100000000) {
      int dim = gen.Uniform(s.dims());
      s.RemoveDim(dim);
      sold.RemoveDim(dim);
    } else if (r < 50 && ne > 0 && ne < 100000000) {
      int dim = gen.Uniform(s.dims() + 1);
      int64_t sz = SkewedSize(&gen, sold.num_elements());
      s.InsertDim(dim, sz);
      sold.InsertDim(dim, sz);
    } else {
      std::vector<int64_t> sizes;
      const int N = (gen.Uniform(4) == 0) ? gen.Uniform(10) : gen.Uniform(3);
      int64_t num_elements = 1;
      for (int i = 0; i < N; i++) {
        int64_t sz = SkewedSize(&gen, num_elements);
        sizes.push_back(sz);
        num_elements *= std::max<int64_t>(1, sz);
      }

      s = TensorShape(sizes);
      sold = TensorShapeOld(sizes);
    }
  }
}

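// 1000^6 = 10^18 elements in the last case below still fits comfortably under
// std::numeric_limits<int64_t>::max() (about 9.22 * 10^18).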
TEST(TensorShapeTest, Large) {
  // We used to cap shapes at 2**40 elements.  Ensure the
  // bound is now higher.
  int64_t one = 1;
  int64_t max = std::numeric_limits<int64_t>::max();
  EXPECT_EQ(TensorShape({max}).num_elements(), max);
  EXPECT_EQ(TensorShape({1, max}).num_elements(), max);
  EXPECT_EQ(TensorShape({max, 1}).num_elements(), max);
  EXPECT_EQ(TensorShape({one << 62}).num_elements(), one << 62);
  EXPECT_EQ(TensorShape({one << 20, one << 41}).num_elements(), one << 61);
  EXPECT_EQ(TensorShape({1000, 1000, 1000, 1000, 1000, 1000}).num_elements(),
            1e18);
}

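// Each dimension list below multiplies out past the int64 range
// (3 * 30 bits = 90 bits; 5 + 60 bits = 65 bits), so shape validation and
// MakeShape must report INVALID_ARGUMENT rather than silently wrapping.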
TEST(TensorShapeTest, Overflow) {
  int64_t one = 1;
  std::vector<std::vector<int64_t>> overflows = {
      {1 << 30, 1 << 30, 1 << 30},
      {1 << 5, (one << 60) + 1},
  };
  for (const auto& overflow : overflows) {
    TensorShapeProto proto;
    for (auto dim : overflow) {
      proto.add_dim()->set_size(dim);
    }
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
              TensorShape::IsValidShape(proto).code());
    TensorShape shape;
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
              TensorShapeUtils::MakeShape(overflow, &shape).code());
  }
}

TEST(TensorShapeTest, UnknownRank) {
  // NOTE(irving): Unfortunately, for historical reasons we have to allow a
  // TensorShapeProto with unknown_rank() set to be parsed as a TensorShape.
  // Would be nice to tighten this, but it's tricky given backwards
  // compatibility requirements.
  TensorShapeProto proto;
  proto.set_unknown_rank(true);
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  EXPECT_EQ(TensorShape(), TensorShape(proto));

  proto.add_dim()->set_size(7);
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  EXPECT_EQ(TensorShape({7}), TensorShape(proto));
}

TEST(TensorShapeUtilsTest, StartsWith) {
  EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({}), TensorShape({})));
  EXPECT_TRUE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({})));
  EXPECT_TRUE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2})));
  EXPECT_TRUE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 3})));
  EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}),
                                           TensorShape({2, 3})));
  EXPECT_FALSE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({3})));
  EXPECT_FALSE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 4})));
  EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3}),
                                            TensorShape({2, 3, 4})));
  EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}),
                                            TensorShape({3, 4})));
}

TEST(TensorShapeUtilsTest, EndsWith) {
  EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({}), TensorShape({})));
  EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({})));
  EXPECT_TRUE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({3})));
  EXPECT_TRUE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3})));
  EXPECT_TRUE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({3, 4})));
  EXPECT_FALSE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2})));
  EXPECT_FALSE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 4})));
  EXPECT_FALSE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3, 4})));
  EXPECT_FALSE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({2, 3})));
}

// A few different test cases for tensor sizes for benchmarks
static std::vector<int64_t> MakeSizes(int arg) {
  std::vector<int64_t> sizes;
  switch (arg) {
    case 0:
      sizes = {100};
      break;
    case 1:
      sizes = {100, 1000};
      break;
    case 2:
      sizes = {100, 1000000};
      break;
    case 3:
      sizes = {100, 256, 192, 3};
      break;
    case 4:
      sizes = {1, 2, 1ll << 34, 1, 1, 1};
      break;
  }
  return sizes;
}

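// The Arg values below select MakeSizes cases; case 4 includes a dimension of
// 1 << 34, which presumably does not fit TensorShape's most compact internal
// representations, so the benchmarks also cover the larger representation.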
void BM_TensorShape_Init(::testing::benchmark::State& state) {
  const int arg = state.range(0);

  auto sizes = MakeSizes(arg);
  for (auto s : state) {
    TensorShape shape(sizes);
    tensorflow::testing::DoNotOptimize(shape.num_elements());
  }
}
BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

void BM_TensorShape_Assign(::testing::benchmark::State& state) {
  const int arg = state.range(0);

  TensorShape shape(MakeSizes(arg));
  for (auto s : state) {
    const TensorShape s2 = shape;
    tensorflow::testing::DoNotOptimize(s2);
  }
}
BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

void BM_TensorShape_SetDim(::testing::benchmark::State& state) {
  const int arg = state.range(0);

  TensorShape shape(MakeSizes(arg));
  tensorflow::testing::DoNotOptimize(shape);
  for (auto s : state) {
    shape.set_dim(0, 8);
  }
}
BENCHMARK(BM_TensorShape_SetDim)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

}  // namespace
}  // namespace tensorflow