1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/tensor_shape.h"
17
18 #include "tensorflow/core/framework/tensor_shape.pb.h"
19 #include "tensorflow/core/lib/core/status_test_util.h"
20 #include "tensorflow/core/lib/random/simple_philox.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/protobuf/error_codes.pb.h"
26
27 namespace tensorflow {
28 class TensorShapeTestHelper {
29 public:
set_data_type(TensorShape * s,DataType t)30 static void set_data_type(TensorShape* s, DataType t) { s->set_data_type(t); }
data_type(const TensorShape * s)31 static uint8 data_type(const TensorShape* s) { return s->data_type(); }
32 };
33
34 namespace {
35
TEST(TensorShapeTest, Default) {
  // A default-constructed shape is a scalar: no dimensions at all, and the
  // empty product of sizes gives exactly one element.
  TensorShape scalar;
  EXPECT_EQ(scalar.dims(), 0);
  EXPECT_EQ(scalar.num_elements(), 1);
}
43
TEST(TensorShapeTest, set_dim) {
  TensorShape shape({10, 5});

  // Grow the outer dimension: [20, 5] -> 100 elements.
  shape.set_dim(0, 20);
  ASSERT_EQ(2, shape.dims());
  EXPECT_EQ(20, shape.dim_size(0));
  EXPECT_EQ(100, shape.num_elements());

  // Shrink the inner dimension: [20, 2] -> 40 elements.
  shape.set_dim(1, 2);
  ASSERT_EQ(2, shape.dims());
  EXPECT_EQ(2, shape.dim_size(1));
  EXPECT_EQ(40, shape.num_elements());
}
57
TEST(TensorShapeTest, RemoveDim) {
  // Removing the outer dimension of [10, 5] leaves [5].
  TensorShape shape({10, 5});
  shape.RemoveDim(0);
  EXPECT_EQ(5, shape.num_elements());
  ASSERT_EQ(1, shape.dims());
}
64
TEST(TensorShapeTest, RemoveAndAddDim) {
  // [10, 5, 20] -> drop the middle dim -> [10, 20] -> append -> [10, 20, 100].
  TensorShape shape({10, 5, 20});
  shape.RemoveDim(1);
  shape.AddDim(100);

  EXPECT_EQ(20000, shape.num_elements());
  ASSERT_EQ(3, shape.dims());
}
73
TEST(TensorShapeTest, RemoveLastDims) {
  TensorShape shape({2, 3, 5, 7});

  // Drop the innermost dimension: [2, 3, 5] -> 30 elements.
  shape.RemoveLastDims(1);
  ASSERT_EQ(3, shape.dims());
  EXPECT_EQ(30, shape.num_elements());

  // Drop two more, leaving only the outermost [2].
  shape.RemoveLastDims(2);
  ASSERT_EQ(1, shape.dims());
  EXPECT_EQ(2, shape.dim_size(0));
}
85
TEST(TensorShapeTest, RemoveDimRange) {
  // Every range whose begin equals its end is empty, so nothing is removed.
  TensorShape trimmed({2, 3, 5, 7});
  for (int pos = -4; pos <= 4; ++pos) {
    trimmed.RemoveDimRange(pos, pos);
    ASSERT_EQ(4, trimmed.dims());
    ASSERT_EQ(210, trimmed.num_elements());
  }

  // Non-negative bounds. A begin past the end is also an empty range.
  trimmed.RemoveDimRange(3, 1);
  ASSERT_EQ(4, trimmed.dims());
  ASSERT_EQ(210, trimmed.num_elements());
  trimmed.RemoveDimRange(0, 3);
  ASSERT_EQ(1, trimmed.dims());
  EXPECT_EQ(7, trimmed.dim_size(0));

  TensorShape middle({2, 3, 5, 7});
  middle.RemoveDimRange(2, 3);
  ASSERT_EQ(3, middle.dims());
  ASSERT_EQ(42, middle.num_elements());

  // Negative bounds count from the end (per the expectations below, with four
  // dims -2 behaves like 3 and -3 like 2).
  TensorShape negative({2, 3, 5, 7});
  negative.RemoveDimRange(-2, -3);  // Empty interval.
  ASSERT_EQ(4, negative.dims());
  ASSERT_EQ(210, negative.num_elements());
  negative.RemoveDimRange(0, -2);
  ASSERT_EQ(1, negative.dims());
  ASSERT_EQ(7, negative.dim_size(0));

  TensorShape negative_middle({2, 3, 5, 7});
  negative_middle.RemoveDimRange(-3, -2);
  ASSERT_EQ(3, negative_middle.dims());
  ASSERT_EQ(42, negative_middle.num_elements());
}
120
TEST(TensorShapeTest, InvalidShapeProto) {
  TensorShapeProto proto;
  // No dimensions: a scalar, which is valid.
  EXPECT_TRUE(TensorShape::IsValid(proto));

  // Ordinary non-negative sizes are valid.
  for (const int64_t size : {357, 982}) proto.add_dim()->set_size(size);
  EXPECT_TRUE(TensorShape::IsValid(proto));

  // Negative sizes are rejected.
  proto.Clear();
  for (const int64_t size : {-357, -982}) proto.add_dim()->set_size(size);
  EXPECT_FALSE(TensorShape::IsValid(proto));

  // Each size is fine on its own, but the product (~2^70) is too large.
  proto.Clear();
  for (const int64_t size : {1LL << 35, (1LL << 35) + 1}) {
    proto.add_dim()->set_size(size);
  }
  EXPECT_FALSE(TensorShape::IsValid(proto));
}
139
TEST(TensorShapeTest, TooManyDimsProto) {
  // IsValid() and IsValidShape() are both checked at every step on purpose,
  // so that both validation paths are exercised.
  TensorShapeProto proto;
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));

  // Exactly MaxDimensions() unit-sized dimensions is still acceptable...
  const int max_dims = TensorShape::MaxDimensions();
  for (int d = 0; d < max_dims; ++d) proto.add_dim()->set_size(1);
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));

  // ...but one more dimension is rejected by both.
  proto.add_dim()->set_size(1);
  EXPECT_FALSE(TensorShape::IsValid(proto));
  EXPECT_FALSE(TensorShape::IsValidShape(proto).ok());
}
154
TEST(TensorShapeTest, SetDimForEmptyTensor) {
  TensorShape shape({10, 5, 20});
  EXPECT_EQ(1000, shape.num_elements());

  // Zeroing one dimension empties the tensor...
  shape.set_dim(1, 0);
  EXPECT_EQ(0, shape.num_elements());

  // ...and restoring a non-zero size must recompute 10 * 7 * 20.
  shape.set_dim(1, 7);
  EXPECT_EQ(1400, shape.num_elements());
}
163
TEST(TensorShapeTest, AppendShape64BitIndices) {
  // 2^31 does not fit in int32; the size must survive both construction and
  // AppendShape() unchanged.
  constexpr int64_t kWideDim = int64_t{1} << 31;
  TensorShape source({10, kWideDim});
  EXPECT_EQ(10, source.dim_size(0));
  EXPECT_EQ(kWideDim, source.dim_size(1));

  TensorShape appended;
  appended.AppendShape(source);
  EXPECT_EQ(10, appended.dim_size(0));
  EXPECT_EQ(kWideDim, appended.dim_size(1));
}
175
TEST(TensorShapeTest, DataType) {
  // Shorthand for reading the data-type tag stored in a shape.
  auto tag_of = [](const TensorShape& t) {
    return TensorShapeTestHelper::data_type(&t);
  };

  TensorShape shape({});
  EXPECT_EQ(tag_of(shape), DT_INVALID);

  // The tag survives dimensions being appended...
  TensorShapeTestHelper::set_data_type(&shape, DT_INT32);
  shape.AddDim(1);
  EXPECT_EQ(tag_of(shape), DT_INT32);
  shape.AddDim(100000);
  EXPECT_EQ(tag_of(shape), DT_INT32);

  TensorShapeTestHelper::set_data_type(&shape, DT_UINT16_REF);
  for (const int64_t size : {2, 4, 3}) {
    shape.AddDim(size);
    EXPECT_EQ(tag_of(shape), DT_UINT16_REF);
  }

  // ...a copy, and a dimension removal. Clear() resets it to DT_INVALID.
  TensorShape copy = shape;
  EXPECT_EQ(tag_of(copy), DT_UINT16_REF);
  copy.RemoveDim(2);
  EXPECT_EQ(tag_of(copy), DT_UINT16_REF);
  TensorShapeTestHelper::set_data_type(&copy, DT_FLOAT);
  EXPECT_EQ(tag_of(copy), DT_FLOAT);
  copy.Clear();
  EXPECT_EQ(tag_of(copy), DT_INVALID);
}
201
TEST(TensorShapeTest, ostream) {
  // Streaming a shape prints its sizes as a bracketed, comma-separated list.
  const TensorShape shape({10, 5, 4});
  std::stringstream out;
  out << shape;
  EXPECT_EQ(out.str(), "[10,5,4]");
}
208
TEST(TensorShapeTest, AddDimWithStatus) {
  TensorShape shape({10, 5, 20});

  // A non-negative size is appended: 10 * 5 * 20 * 400 elements.
  const Status accepted = shape.AddDimWithStatus(400);
  EXPECT_TRUE(accepted.ok());
  EXPECT_EQ(400000, shape.num_elements());
  ASSERT_EQ(4, shape.dims());

  // A negative size is reported through the returned status.
  const Status rejected = shape.AddDimWithStatus(-1);
  EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, rejected.code());
}
219
TEST(TensorShapeTest, Factory) {
  TensorShape shape;

  // Valid sizes build a [10, 5, 20] shape.
  const Status built = TensorShape::BuildTensorShapeBase({10, 5, 20}, &shape);
  EXPECT_TRUE(built.ok());
  EXPECT_EQ(1000, shape.num_elements());
  ASSERT_EQ(3, shape.dims());

  // A negative size is rejected with INVALID_ARGUMENT.
  const Status refused =
      TensorShape::BuildTensorShapeBase({-10, 5, 20}, &shape);
  EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, refused.code());
}
230
// -----------------------------------------------------------------------
// An old implementation of TensorShape using a different representation,
// preserved here in the unittest so that a randomized test can verify that
// TensorShape and TensorShapeOld behave identically.
// Forward declaration: TensorShapeOld::begin()/end() return this type, whose
// definition follows TensorShapeOld.
class TensorShapeIterOld;  // Defined below
237
238 /// Manages the dimensions of a Tensor and their sizes.
239 class TensorShapeOld {
240 public:
241 /// \brief Construct a `TensorShape` from the provided sizes.
242 /// REQUIRES: `dim_sizes[i] >= 0`
243 explicit TensorShapeOld(gtl::ArraySlice<int64_t> dim_sizes);
TensorShapeOld(std::initializer_list<int64_t> dim_sizes)244 TensorShapeOld(std::initializer_list<int64_t> dim_sizes)
245 : TensorShapeOld(gtl::ArraySlice<int64_t>(dim_sizes)) {}
246
247 /// REQUIRES: `IsValid(proto)`
248 explicit TensorShapeOld(const TensorShapeProto& proto);
249
250 /// Create a tensor shape with no dimensions and one element, which you can
251 /// then call `AddDim()` on.
252 TensorShapeOld();
253
254 /// Returns `true` iff `proto` is a valid tensor shape.
255 static bool IsValid(const TensorShapeProto& proto);
256
257 /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
258 /// status otherwise.
259 static Status IsValidShape(const TensorShapeProto& proto);
260
261 /// Clear a tensor shape
262 void Clear();
263
264 /// \brief Add a dimension to the end ("inner-most").
265 /// REQUIRES: `size >= 0`
266 void AddDim(int64_t size);
267
268 /// Appends all the dimensions from `shape`.
269 void AppendShape(const TensorShapeOld& shape);
270
271 /// \brief Insert a dimension somewhere in the `TensorShape`.
272 /// REQUIRES: `0 <= d <= dims()`
273 /// REQUIRES: `size >= 0`
274 void InsertDim(int d, int64_t size);
275
276 /// \brief Modifies the size of the dimension `d` to be `size`
277 /// REQUIRES: `0 <= d < dims()`
278 /// REQUIRES: `size >= 0`
279 void set_dim(int d, int64_t size);
280
281 /// \brief Removes dimension `d` from the `TensorShape`.
282 /// REQUIRES: `0 <= d < dims()`
283 void RemoveDim(int d);
284
285 /// Return the number of dimensions in the tensor.
dims() const286 int dims() const { return dim_sizes_.size(); }
287
288 /// \brief Returns the number of elements in dimension `d`.
289 /// REQUIRES: `0 <= d < dims()`
290 // TODO(touts): Rename to `dimension()` to match
291 // `Eigen::Tensor::dimension()`?
dim_size(int d) const292 int64_t dim_size(int d) const {
293 DCHECK_GE(d, 0);
294 DCHECK_LT(d, dims());
295 return dim_sizes_[d];
296 }
297
298 /// Returns sizes of all dimensions.
dim_sizes() const299 gtl::ArraySlice<int64_t> dim_sizes() const { return dim_sizes_; }
300
301 /// \brief Returns the number of elements in the tensor.
302 ///
303 /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
304 /// which uses `ptrdiff_t`.
num_elements() const305 int64_t num_elements() const { return num_elements_; }
306
307 /// Returns true if `*this` and `b` have the same sizes. Ignores
308 /// dimension names.
309 bool IsSameSize(const TensorShapeOld& b) const;
operator ==(const TensorShapeOld & b) const310 bool operator==(const TensorShapeOld& b) const { return IsSameSize(b); }
311
312 /// Fill `*proto` from `*this`.
313 void AsProto(TensorShapeProto* proto) const;
314
315 /// Fill `*dsizes` from `*this`.
316 template <int NDIMS>
317 Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;
318
319 /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
320 /// which case we pad the rest of the sizes with 1.
321 template <int NDIMS>
322 Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;
323
324 /// For iterating through the dimensions.
325 TensorShapeIterOld begin() const;
326 TensorShapeIterOld end() const;
327
328 /// For error messages.
329 string DebugString() const;
330
331 /// Same as `TensorShape(proto).DebugString()` but doesn't crash for
332 /// invalid protos.
333 static string DebugString(const TensorShapeProto& proto);
334
335 private:
336 // Recalculates the dimensions of this tensor after they are modified.
337 void recompute_dims();
338
339 // TODO(josh11b): Maybe use something from the Eigen Tensor library
340 // for the sizes.
341 gtl::InlinedVector<int64_t, 4> dim_sizes_;
342
343 // total number of elements (avoids recomputing it each time).
344 int64_t num_elements_;
345 };
346
// What dereferencing a TensorShapeIterOld yields: the size of one dimension.
struct TensorShapeDimOld {
  explicit TensorShapeDimOld(int64_t dim_size) : size{dim_size} {}
  int64_t size;
};
351
352 class TensorShapeIterOld {
353 public:
TensorShapeIterOld(const TensorShapeOld * shape,int d)354 TensorShapeIterOld(const TensorShapeOld* shape, int d)
355 : shape_(shape), d_(d) {}
operator ==(const TensorShapeIterOld & rhs)356 bool operator==(const TensorShapeIterOld& rhs) {
357 DCHECK(shape_ == rhs.shape_);
358 return d_ == rhs.d_;
359 }
operator !=(const TensorShapeIterOld & rhs)360 bool operator!=(const TensorShapeIterOld& rhs) {
361 DCHECK(shape_ == rhs.shape_);
362 return d_ != rhs.d_;
363 }
operator ++()364 void operator++() { ++d_; }
operator *()365 TensorShapeDimOld operator*() {
366 return TensorShapeDimOld(shape_->dim_size(d_));
367 }
368
369 private:
370 const TensorShapeOld* shape_;
371 int d_;
372 };
373
// Upper limit (2^40) on the total number of elements a TensorShapeOld may
// hold. This file already places it in an anonymous namespace, so `static`
// adds nothing.
constexpr int64_t kMaxElements = int64_t{1} << 40;
376
IsValid(const TensorShapeProto & proto)377 bool TensorShapeOld::IsValid(const TensorShapeProto& proto) {
378 int64_t num_elements = 1;
379 for (const auto& d : proto.dim()) {
380 if (d.size() < 0) return false;
381 num_elements *= d.size();
382 if (num_elements > kMaxElements) return false;
383 }
384 return true;
385 }
386
IsValidShape(const TensorShapeProto & proto)387 Status TensorShapeOld::IsValidShape(const TensorShapeProto& proto) {
388 int64_t num_elements = 1;
389 for (const auto& d : proto.dim()) {
390 if (d.size() < 0) {
391 return errors::InvalidArgument("Shape ", DebugString(proto),
392 " has negative dimensions; ",
393 "perhaps an un-fed placeholder?");
394 }
395 num_elements *= d.size();
396 if (num_elements > kMaxElements) {
397 return errors::InvalidArgument("Shape ", DebugString(proto),
398 " is too large (more than ", kMaxElements,
399 " entries)");
400 }
401 }
402 return OkStatus();
403 }
404
TensorShapeOld(const TensorShapeProto & proto)405 TensorShapeOld::TensorShapeOld(const TensorShapeProto& proto) {
406 dim_sizes_.reserve(proto.dim_size());
407 num_elements_ = 1;
408 for (const auto& d : proto.dim()) {
409 AddDim(d.size());
410 }
411 }
412
TensorShapeOld(gtl::ArraySlice<int64_t> dim_sizes)413 TensorShapeOld::TensorShapeOld(gtl::ArraySlice<int64_t> dim_sizes) {
414 dim_sizes_.reserve(dim_sizes.size());
415 num_elements_ = 1;
416 for (auto s : dim_sizes) {
417 AddDim(s);
418 }
419 }
420
TensorShapeOld()421 TensorShapeOld::TensorShapeOld() : num_elements_(1) {}
422
Clear()423 void TensorShapeOld::Clear() {
424 dim_sizes_.clear();
425 num_elements_ = 1;
426 }
427
AddDim(int64_t size)428 void TensorShapeOld::AddDim(int64_t size) {
429 CHECK_GE(size, 0);
430 dim_sizes_.push_back(size);
431 num_elements_ *= size;
432 CHECK_LE(0, num_elements_);
433 CHECK_LE(num_elements_, kMaxElements);
434 }
435
AppendShape(const TensorShapeOld & shape)436 void TensorShapeOld::AppendShape(const TensorShapeOld& shape) {
437 for (auto d : shape) AddDim(d.size);
438 }
439
InsertDim(int d,int64_t size)440 void TensorShapeOld::InsertDim(int d, int64_t size) {
441 CHECK_GE(d, 0);
442 CHECK_LE(d, dims());
443 CHECK_GE(size, 0);
444 dim_sizes_.insert(dim_sizes_.begin() + d, size);
445 num_elements_ *= size;
446 CHECK_LE(0, num_elements_);
447 CHECK_LE(num_elements_, kMaxElements);
448 }
449
set_dim(int d,int64_t size)450 void TensorShapeOld::set_dim(int d, int64_t size) {
451 CHECK_GE(d, 0);
452 CHECK_LT(d, dims());
453 CHECK_GE(size, 0);
454
455 // Update the number of elements. num_elements_ is int64.
456 dim_sizes_[d] = size;
457 recompute_dims();
458 }
459
RemoveDim(int d)460 void TensorShapeOld::RemoveDim(int d) {
461 CHECK_GE(d, 0);
462 CHECK_LT(d, dims());
463
464 // Update the number of elements and remove the dimension from the
465 // sizes.
466 dim_sizes_.erase(dim_sizes_.begin() + d);
467 recompute_dims();
468 }
469
recompute_dims()470 void TensorShapeOld::recompute_dims() {
471 num_elements_ = 1;
472 for (auto s : dim_sizes_) {
473 num_elements_ *= s;
474 CHECK_LE(0, num_elements_);
475 CHECK_LE(num_elements_, kMaxElements);
476 }
477 }
478
IsSameSize(const TensorShapeOld & b) const479 bool TensorShapeOld::IsSameSize(const TensorShapeOld& b) const {
480 if (b.dims() != dims()) return false;
481 for (int d = 0; d < dims(); d++) {
482 if (dim_size(d) != b.dim_size(d)) return false;
483 }
484 return true;
485 }
486
AsProto(TensorShapeProto * proto) const487 void TensorShapeOld::AsProto(TensorShapeProto* proto) const {
488 proto->Clear();
489 for (size_t d = 0; d < dim_sizes_.size(); ++d) {
490 auto* dim = proto->add_dim();
491 dim->set_size(dim_sizes_[d]);
492 }
493 }
494
begin() const495 TensorShapeIterOld TensorShapeOld::begin() const {
496 return TensorShapeIterOld(this, 0);
497 }
498
end() const499 TensorShapeIterOld TensorShapeOld::end() const {
500 return TensorShapeIterOld(this, dims());
501 }
502
DebugString() const503 string TensorShapeOld::DebugString() const {
504 return strings::StrCat(
505 "[", absl::StrJoin(gtl::ArraySlice<int64_t>(dim_sizes_), ","), "]");
506 }
507
DebugString(const TensorShapeProto & proto)508 string TensorShapeOld::DebugString(const TensorShapeProto& proto) {
509 string s = "[";
510 bool first = true;
511 for (const auto& d : proto.dim()) {
512 strings::StrAppend(&s, first ? "" : ",", d.size());
513 first = false;
514 }
515 strings::StrAppend(&s, "]");
516 return s;
517 }
518 // End of old implementation
519 // ------------------------------------------------------------------------
520
SkewedSize(random::SimplePhilox * gen,int64_t current_elements)521 static int64_t SkewedSize(random::SimplePhilox* gen, int64_t current_elements) {
522 int64_t result = 0;
523 do {
524 if (current_elements < 100) {
525 result = gen->Uniform(100000);
526 } else {
527 result = gen->Uniform(2);
528 }
529 } while ((result * current_elements >= 1LL << 34) ||
530 (result * current_elements < 0));
531 return result;
532 }
533
TEST(TensorShapeTest, Randomized) {
  // We do a randomized test to verify that the behavior of the
  // TensorShape implementation (which changes representations depending
  // on the values) is identical to our older, more straightforward (but
  // more memory hungry) implementation (TensorShapeOld).
  random::PhiloxRandom philox(7, 7);
  random::SimplePhilox gen(&philox);
  TensorShape s;
  TensorShapeOld sold;
  TensorShapeProto sp;
  TensorShapeProto spold;
  LOG(INFO) << "Sizes: " << sizeof(TensorShape) << " vs "
            << sizeof(TensorShapeOld);
  for (int i = 0; i < 100000; i++) {
    // At every step both implementations must serialize identically and agree
    // on the element count.
    s.AsProto(&sp);
    sold.AsProto(&spold);
    EXPECT_EQ(sp.DebugString(), spold.DebugString());
    if ((i % 1000) == 0) {
      fprintf(stderr, "ITERATION %d: %s\n", i, sp.DebugString().c_str());
    }
    EXPECT_EQ(s.num_elements(), sold.num_elements());

    // Test moves.
    TensorShape copy = s;
    TensorShape moved(std::move(copy));
    EXPECT_EQ(s, moved);
    copy = s;
    moved = std::move(copy);
    EXPECT_EQ(s, moved);

    // Apply the same randomly chosen mutation to both shapes.
    const int64_t ne = sold.num_elements();
    const int r = gen.Uniform(100);
    if (r < 10) {
      // Append a dimension.
      int64_t sz = SkewedSize(&gen, sold.num_elements());
      s.AddDim(sz);
      sold.AddDim(sz);
    } else if (r < 15) {
      // Reset to a scalar.
      s.Clear();
      sold.Clear();
    } else if (r < 35 && s.dims() > 0 && ne > 0 && ne < 100000000) {
      // Remove a random dimension from a non-empty, moderately sized shape.
      int dim = gen.Uniform(s.dims());
      s.RemoveDim(dim);
      sold.RemoveDim(dim);
    } else if (r < 50 && ne > 0 && ne < 100000000) {
      // Insert a dimension at a random position (including the end).
      int dim = gen.Uniform(s.dims() + 1);
      int64_t sz = SkewedSize(&gen, sold.num_elements());
      s.InsertDim(dim, sz);
      sold.InsertDim(dim, sz);
    } else {
      // Rebuild both shapes from a fresh random list of sizes.
      std::vector<int64_t> sizes;
      const int N = (gen.Uniform(4) == 0) ? gen.Uniform(10) : gen.Uniform(3);
      int64_t num_elements = 1;
      // `j`, not `i`: reusing `i` here shadowed the outer iteration counter.
      for (int j = 0; j < N; j++) {
        int64_t sz = SkewedSize(&gen, num_elements);
        sizes.push_back(sz);
        num_elements *= std::max<int64_t>(1, sz);
      }

      s = TensorShape(sizes);
      sold = TensorShapeOld(sizes);
    }
  }
}
597
TEST(TensorShapeTest, Large) {
  // We used to cap shapes at 2**40 elements. Ensure the
  // bound is now higher.
  int64_t one = 1;
  int64_t max = std::numeric_limits<int64_t>::max();
  EXPECT_EQ(TensorShape({max}).num_elements(), max);
  EXPECT_EQ(TensorShape({1, max}).num_elements(), max);
  EXPECT_EQ(TensorShape({max, 1}).num_elements(), max);
  EXPECT_EQ(TensorShape({one << 62}).num_elements(), one << 62);
  EXPECT_EQ(TensorShape({one << 20, one << 41}).num_elements(), one << 61);
  // Compare against an exact int64 (10^18) rather than the double literal
  // 1e18, so the check stays in integer arithmetic and does not rely on an
  // int64 -> double conversion.
  EXPECT_EQ(TensorShape({1000, 1000, 1000, 1000, 1000, 1000}).num_elements(),
            int64_t{1000000000000000000});
}
611
TEST(TensorShapeTest, Overflow) {
  // Every size below is non-negative, but each list's product overflows
  // int64_t (2^90 and 2^65 + 32). Both validation entry points must reject
  // them.
  const int64_t kOne = 1;
  const std::vector<std::vector<int64_t>> cases = {
      {1 << 30, 1 << 30, 1 << 30},
      {1 << 5, (kOne << 60) + 1},
  };
  for (const std::vector<int64_t>& dims : cases) {
    TensorShapeProto proto;
    for (const int64_t size : dims) proto.add_dim()->set_size(size);
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
              TensorShape::IsValidShape(proto).code());

    TensorShape shape;
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
              TensorShapeUtils::MakeShape(dims, &shape).code());
  }
}
630
TEST(TensorShapeTest, UnknownRank) {
  // NOTE(irving): Unfortunately, for historical reasons we have to allow an
  // TensorShapeProto with unknown_rank() set to be parsed as a TensorShape.
  // Would be nice to tighten this, but it's tricky given backwards
  // compatibility requirements.
  TensorShapeProto unknown;
  unknown.set_unknown_rank(true);

  // With no dims, an unknown-rank proto parses as the scalar shape.
  EXPECT_TRUE(TensorShape::IsValid(unknown));
  TF_EXPECT_OK(TensorShape::IsValidShape(unknown));
  EXPECT_EQ(TensorShape(), TensorShape(unknown));

  // With dims present, those dims are used and the flag is ignored.
  unknown.add_dim()->set_size(7);
  EXPECT_TRUE(TensorShape::IsValid(unknown));
  TF_EXPECT_OK(TensorShape::IsValidShape(unknown));
  EXPECT_EQ(TensorShape({7}), TensorShape(unknown));
}
647
TEST(TensorShapeUtilsTest, StartsWith) {
  // has_prefix(shape, prefix) is true iff `prefix` matches the leading
  // dimensions of `shape`.
  const auto has_prefix = [](const TensorShape& shape,
                             const TensorShape& prefix) {
    return TensorShapeUtils::StartsWith(shape, prefix);
  };

  // Empty prefixes, proper prefixes, and a shape against itself all match.
  EXPECT_TRUE(has_prefix(TensorShape({}), TensorShape({})));
  EXPECT_TRUE(has_prefix(TensorShape({2, 3}), TensorShape({})));
  EXPECT_TRUE(has_prefix(TensorShape({2, 3}), TensorShape({2})));
  EXPECT_TRUE(has_prefix(TensorShape({2, 3}), TensorShape({2, 3})));
  EXPECT_TRUE(has_prefix(TensorShape({2, 3, 4}), TensorShape({2, 3})));

  // Mismatched leading sizes, or a prefix longer than the shape, do not.
  EXPECT_FALSE(has_prefix(TensorShape({2, 3}), TensorShape({3})));
  EXPECT_FALSE(has_prefix(TensorShape({2, 3}), TensorShape({2, 4})));
  EXPECT_FALSE(has_prefix(TensorShape({2, 3}), TensorShape({2, 3, 4})));
  EXPECT_FALSE(has_prefix(TensorShape({2, 3, 4}), TensorShape({3, 4})));
}
667
TEST(TensorShapeUtilsTest, EndsWith) {
  // has_suffix(shape, suffix) is true iff `suffix` matches the trailing
  // dimensions of `shape`.
  const auto has_suffix = [](const TensorShape& shape,
                             const TensorShape& suffix) {
    return TensorShapeUtils::EndsWith(shape, suffix);
  };

  // Empty suffixes, proper suffixes, and a shape against itself all match.
  EXPECT_TRUE(has_suffix(TensorShape({}), TensorShape({})));
  EXPECT_TRUE(has_suffix(TensorShape({2, 3}), TensorShape({})));
  EXPECT_TRUE(has_suffix(TensorShape({2, 3}), TensorShape({3})));
  EXPECT_TRUE(has_suffix(TensorShape({2, 3}), TensorShape({2, 3})));
  EXPECT_TRUE(has_suffix(TensorShape({2, 3, 4}), TensorShape({3, 4})));

  // Mismatched trailing sizes, or a suffix longer than the shape, do not.
  EXPECT_FALSE(has_suffix(TensorShape({2, 3}), TensorShape({2})));
  EXPECT_FALSE(has_suffix(TensorShape({2, 3}), TensorShape({2, 4})));
  EXPECT_FALSE(has_suffix(TensorShape({2, 3}), TensorShape({2, 3, 4})));
  EXPECT_FALSE(has_suffix(TensorShape({2, 3, 4}), TensorShape({2, 3})));
}
686
// Returns one of five representative size lists, selected by `arg` (0-4), for
// the benchmarks. Any other `arg` yields an empty list.
static std::vector<int64_t> MakeSizes(int arg) {
  if (arg == 0) return {100};
  if (arg == 1) return {100, 1000};
  if (arg == 2) return {100, 1000000};
  if (arg == 3) return {100, 256, 192, 3};
  // Case 4 has a single dimension of 2^34 surrounded by unit dimensions.
  if (arg == 4) return {1, 2, int64_t{1} << 34, 1, 1, 1};
  return {};
}
709
BM_TensorShape_Init(::testing::benchmark::State & state)710 void BM_TensorShape_Init(::testing::benchmark::State& state) {
711 const int arg = state.range(0);
712
713 auto sizes = MakeSizes(arg);
714 for (auto s : state) {
715 TensorShape shape(sizes);
716 tensorflow::testing::DoNotOptimize(shape.num_elements());
717 }
718 }
719 BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
720
BM_TensorShape_Assign(::testing::benchmark::State & state)721 void BM_TensorShape_Assign(::testing::benchmark::State& state) {
722 const int arg = state.range(0);
723
724 TensorShape shape(MakeSizes(arg));
725 for (auto s : state) {
726 const TensorShape s2 = shape;
727 tensorflow::testing::DoNotOptimize(s2);
728 }
729 }
730 BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
731
BM_TensorShape_SetDim(::testing::benchmark::State & state)732 void BM_TensorShape_SetDim(::testing::benchmark::State& state) {
733 const int arg = state.range(0);
734
735 TensorShape shape(MakeSizes(arg));
736 tensorflow::testing::DoNotOptimize(shape);
737 for (auto s : state) {
738 shape.set_dim(0, 8);
739 }
740 }
741 BENCHMARK(BM_TensorShape_SetDim)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
742
743 } // namespace
744 } // namespace tensorflow
745