1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "tensorflow/core/framework/fake_input.h"
18 #include "tensorflow/core/framework/node_def_builder.h"
19 #include "tensorflow/core/framework/op_def_builder.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28
29 namespace tensorflow {
30 namespace shape_inference {
31 namespace {
32
// Non-fatally expects that string-like value X contains substring Y; on
// failure the message shows both values. Each argument is evaluated once.
// The do { ... } while (false) wrapper (without a trailing semicolon) makes the
// macro a single statement: callers supply the ';', and the macro is safe in an
// unbraced if/else.
#define EXPECT_CONTAINS(X, Y)                                    \
  do {                                                           \
    auto XSTR = (X);                                             \
    auto YSTR = (Y);                                             \
    EXPECT_TRUE(absl::StrContains(XSTR, YSTR))                   \
        << "'" << XSTR << "' does not contain '" << YSTR << "'"; \
  } while (false)
40
MakeOpDefWithLists()41 OpDef MakeOpDefWithLists() {
42 OpRegistrationData op_reg_data;
43 OpDefBuilder b("dummy");
44 b.Input(strings::StrCat("input: N * float"));
45 b.Output(strings::StrCat("output: N * float"));
46 CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok());
47 return op_reg_data.op_def;
48 }
49
S(std::initializer_list<int64> dims)50 PartialTensorShape S(std::initializer_list<int64> dims) {
51 return PartialTensorShape(dims);
52 }
53
Unknown()54 PartialTensorShape Unknown() { return PartialTensorShape(); }
55
56 } // namespace
57
58 class ShapeInferenceTest : public ::testing::Test {
59 protected:
60 // These give access to private functions of DimensionHandle and ShapeHandle.
SameHandle(DimensionHandle a,DimensionHandle b)61 bool SameHandle(DimensionHandle a, DimensionHandle b) {
62 return a.SameHandle(b);
63 }
SameHandle(ShapeHandle a,ShapeHandle b)64 bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
IsSet(DimensionHandle d)65 bool IsSet(DimensionHandle d) { return d.IsSet(); }
IsSet(ShapeHandle s)66 bool IsSet(ShapeHandle s) { return s.IsSet(); }
Relax(InferenceContext * c,DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)67 void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1,
68 DimensionHandle* out) {
69 c->Relax(d0, d1, out);
70 }
Relax(InferenceContext * c,ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)71 void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1,
72 ShapeHandle* out) {
73 c->Relax(s0, s1, out);
74 }
75 void TestMergeHandles(bool input_not_output);
76 void TestRelaxHandles(bool input_not_output);
77
78 static constexpr int kVersion = 0; // used for graph-def version.
79 };
80
// Exercises the by-name accessors InferenceContext::input(name, ...) and
// set_output(name, ...) on an op whose input and output are N-element lists.
TEST_F(ShapeInferenceTest, InputOutputByName) {
  // Set up the test to contain an input tensor list of size 3.
  OpDef op_def = MakeOpDefWithLists();
  NodeDef def;
  // Fail fast if the NodeDef cannot be built; every check below relies on it.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_def)
                   .Attr("N", 3)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
  InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
                     {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2))));

  // Test getters. Success is asserted (not merely expected) so the shapes[i]
  // accesses below stay in bounds.
  std::vector<ShapeHandle> shapes;
  EXPECT_FALSE(c.input("nonexistent", &shapes).ok());
  TF_ASSERT_OK(c.input("input", &shapes));
  ASSERT_EQ(3, shapes.size());
  EXPECT_EQ("[1,5]", c.DebugString(shapes[0]));
  EXPECT_EQ("[2,5]", c.DebugString(shapes[1]));
  EXPECT_EQ("[1,3]", c.DebugString(shapes[2]));

  // Test setters. The outputs read below must actually have been set.
  EXPECT_FALSE(c.set_output("nonexistent", shapes).ok());
  TF_ASSERT_OK(c.set_output("output", shapes));
  EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2))));
}
110
MakeOpDef(int num_inputs,int num_outputs)111 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
112 OpRegistrationData op_reg_data;
113 OpDefBuilder b("dummy");
114 for (int i = 0; i < num_inputs; ++i) {
115 b.Input(strings::StrCat("i", i, ": float"));
116 }
117 for (int i = 0; i < num_outputs; ++i) {
118 b.Output(strings::StrCat("o", i, ": float"));
119 }
120 CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
121 return op_reg_data.op_def;
122 }
123
// InferenceContext::Value on plain constants rather than dimension handles.
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(1, 1), {Unknown()}, {}, {},
                       {});

  // Both the unknown-dimension sentinel and a known size come back as given.
  const int64 unknown_dim = InferenceContext::kUnknownDim;
  EXPECT_EQ(unknown_dim, ctx.Value(unknown_dim));
  EXPECT_EQ(1, ctx.Value(1));

  // Other negative values are rejected by a DCHECK, so this death test only
  // applies to builds where DCHECKs are enabled.
#ifndef NDEBUG
  EXPECT_DEATH(ctx.Value(-7), "Dimension must be non\\-negative or equal to");
#endif
}
136
// InferenceContext::Run: success passes through, and a failure is decorated
// with the node name and op type.
TEST_F(ShapeInferenceTest, Run) {
  NodeDef def;
  def.set_name("foo");
  def.set_op("foo_op");
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  // The rank-1 input satisfies "rank at most 6".
  const auto rank_at_most_6 = [](InferenceContext* ctx) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 6, &unused));
    ctx->set_output(0, ctx->input(0));
    ctx->set_output(1, ctx->input(0));
    return Status::OK();
  };
  TF_ASSERT_OK(c.Run(rank_at_most_6));

  // The same input violates "rank at most 0"; Run attaches extra context to
  // the returned error.
  const auto rank_at_most_0 = [](InferenceContext* ctx) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 0, &unused));
    ctx->set_output(0, ctx->input(0));
    ctx->set_output(1, ctx->input(0));
    return Status::OK();
  };
  const auto msg = c.Run(rank_at_most_0).ToString();
  EXPECT_CONTAINS(msg, "Shape must be at most rank 0 but is rank 1");
  EXPECT_CONTAINS(msg, "node foo");
  EXPECT_CONTAINS(msg, "foo_op");
}
170
171 // Tests different context data added when Run returns error.
// Checks the context that Run() attaches to a failing shape function's error:
// node name and op, input shapes, and any constant input tensors or
// tensors-as-partial-shapes that the shape function requested.
TEST_F(ShapeInferenceTest, AttachContext) {
  NodeDef def;
  def.set_name("foo");
  def.set_op("foo_op");

  // Used by the last two cases: reads both inputs as shape tensors, then fails
  // because input 0 has rank 1.
  const auto shapes_from_tensors_then_fail = [](InferenceContext* ctx) {
    ShapeHandle from_tensor;
    TF_RETURN_IF_ERROR(ctx->MakeShapeFromShapeTensor(0, &from_tensor));
    TF_RETURN_IF_ERROR(ctx->MakeShapeFromShapeTensor(1, &from_tensor));
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 0, &unused));
    ctx->set_output(0, ctx->input(0));
    return Status::OK();
  };

  // No constant tensors requested: only the input shapes are reported.
  {
    InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
                       {});
    TF_ASSERT_OK(c.construction_status());
    const auto rank_at_most_0 = [](InferenceContext* ctx) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 0, &unused));
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    };
    const auto msg = c.Run(rank_at_most_0).ToString();
    EXPECT_CONTAINS(msg, "Shape must be at most rank 0 but is rank 3");
    EXPECT_CONTAINS(msg, "node foo");
    EXPECT_CONTAINS(msg, "foo_op");
    EXPECT_CONTAINS(msg, "input shapes: [1,2,3]");
  }

  // Constant tensor values requested: a non-null one is listed, a null one
  // is not.
  {
    Tensor input_t =
        ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
    InferenceContext c(kVersion, def, MakeOpDef(2, 2),
                       {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
    TF_ASSERT_OK(c.construction_status());
    const auto read_tensors_then_fail = [](InferenceContext* ctx) {
      ctx->input_tensor(0);  // Requested but null: absent from the error.
      ctx->input_tensor(1);  // Requested and non-null: listed in the error.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 0, &unused));
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    };
    const auto msg = c.Run(read_tensors_then_fail).ToString();
    EXPECT_CONTAINS(msg, "Shape must be at most rank 0 but is rank 3");
    EXPECT_CONTAINS(msg, "node foo");
    EXPECT_CONTAINS(msg, "foo_op");
    EXPECT_CONTAINS(
        msg,
        "input shapes: [1,2,3], [4,5] and with computed input tensors: "
        "input[1] = <1.1 2.2 3.3 4.4 5.5>.");
  }

  // Constant tensors requested as shapes, with no partial shapes supplied.
  {
    Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
    InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
                       {nullptr, &input_t}, {}, {});
    TF_ASSERT_OK(c.construction_status());
    const auto msg = c.Run(shapes_from_tensors_then_fail).ToString();
    EXPECT_CONTAINS(msg, "Shape must be at most rank 0 but is rank 1");
    EXPECT_CONTAINS(msg, "node foo");
    EXPECT_CONTAINS(msg, "foo_op");
    EXPECT_CONTAINS(
        msg,
        "with input shapes: [3], [4] and with computed input tensors: input[1] "
        "= <1 2 3 4 5>.");
  }

  // Constant tensors requested as shapes, with a partial shape supplied for
  // input 0.
  {
    Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
    InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
                       {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
    TF_ASSERT_OK(c.construction_status());
    const auto msg = c.Run(shapes_from_tensors_then_fail).ToString();
    EXPECT_CONTAINS(msg, "Shape must be at most rank 0 but is rank 1");
    EXPECT_CONTAINS(msg, "node foo");
    EXPECT_CONTAINS(msg, "foo_op");
    EXPECT_CONTAINS(
        msg,
        "with input shapes: [3], [4] and with computed "
        "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed "
        "as partial shapes: input[0] = [10,?,5].");
  }
}
272
// Rank and per-dimension queries on unknown-rank, partially known and scalar
// inputs.
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
  EXPECT_EQ(3, ctx.num_inputs());
  EXPECT_EQ(2, ctx.num_outputs());

  // Unknown rank: every dimension lookup, including negative and far
  // out-of-range indices, prints as "?".
  const ShapeHandle unknown_rank = ctx.input(0);
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_FALSE(ctx.RankKnown(unknown_rank));
  EXPECT_EQ(InferenceContext::kUnknownRank, ctx.Rank(unknown_rank));
  for (const int idx : {0, -1, 1000}) {
    EXPECT_EQ("?", ctx.DebugString(ctx.Dim(unknown_rank, idx)));
  }

  // Rank 3 with an unknown middle dimension: index i and index i - 3 address
  // the same handle.
  const ShapeHandle rank3 = ctx.input(1);
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
  EXPECT_TRUE(ctx.RankKnown(rank3));
  EXPECT_EQ(3, ctx.Rank(rank3));
  const int64 expected_values[] = {1, InferenceContext::kUnknownDim, 3};
  const char* const expected_text[] = {"1", "?", "3"};
  for (int i = 0; i < 3; ++i) {
    const DimensionHandle dim = ctx.Dim(rank3, i);
    EXPECT_EQ(expected_values[i], ctx.Value(dim));
    EXPECT_EQ(expected_values[i] != InferenceContext::kUnknownDim,
              ctx.ValueKnown(dim));
    EXPECT_TRUE(SameHandle(dim, ctx.Dim(rank3, i - 3)));
    EXPECT_EQ(expected_text[i], ctx.DebugString(dim));
  }

  // A scalar has known rank 0.
  const ShapeHandle scalar = ctx.input(2);
  EXPECT_EQ("[]", ctx.DebugString(scalar));
  EXPECT_TRUE(ctx.RankKnown(scalar));
  EXPECT_EQ(0, ctx.Rank(scalar));
}
313
// NumElements on unknown-rank, partially known and fully known shapes.
TEST_F(ShapeInferenceTest, NumElements) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {},
                       {});
  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle partially_known = ctx.input(1);
  const ShapeHandle fully_known = ctx.input(2);

  // An unknown rank or any unknown dimension makes the count unknown.
  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(unknown_rank)));
  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(partially_known)));

  // That unknown count is a different handle from the input's own unknown
  // dimension.
  EXPECT_FALSE(SameHandle(ctx.Dim(partially_known, 1),
                          ctx.NumElements(partially_known)));

  // Fully known: 5 * 4 * 3 * 2.
  EXPECT_EQ("120", ctx.DebugString(ctx.NumElements(fully_known)));
}
327
// WithRank on unknown-rank and known-rank inputs.
TEST_F(ShapeInferenceTest, WithRank) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});

  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle first;
  ShapeHandle second;

  // Constraining an unknown-rank shape always succeeds and yields a shape of
  // the requested rank; distinct calls yield distinct handles.
  EXPECT_TRUE(ctx.WithRank(unknown_rank, 1, &first).ok());
  EXPECT_EQ("[?]", ctx.DebugString(first));

  EXPECT_TRUE(ctx.WithRank(unknown_rank, 2, &second).ok());
  EXPECT_EQ("[?,?]", ctx.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
  EXPECT_FALSE(SameHandle(ctx.Dim(second, 0), ctx.Dim(second, 1)));

  EXPECT_TRUE(ctx.WithRank(unknown_rank, 1, &second).ok());
  EXPECT_EQ("[?]", ctx.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));

  EXPECT_TRUE(ctx.WithRank(unknown_rank, 0, &first).ok());
  EXPECT_EQ("[]", ctx.DebugString(first));

  // A known rank must match exactly. A mismatch fails and leaves the output
  // unset; a match returns the input handle itself.
  first = rank3;
  const Status status = ctx.WithRank(rank3, 2, &first);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(), "Shape must be rank 2 but is rank 3");
  EXPECT_FALSE(IsSet(first));
  EXPECT_TRUE(ctx.WithRank(rank3, 3, &first).ok());
  EXPECT_TRUE(SameHandle(first, rank3));

  // The inputs themselves are untouched.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
367
// WithRankAtMost on unknown-rank and known-rank inputs.
TEST_F(ShapeInferenceTest, WithRankAtMost) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});

  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle first;
  ShapeHandle second;

  // An unknown-rank shape passes any upper bound and is returned unchanged.
  EXPECT_TRUE(ctx.WithRankAtMost(unknown_rank, 1, &first).ok());
  EXPECT_EQ("?", ctx.DebugString(first));
  EXPECT_TRUE(SameHandle(unknown_rank, first));

  EXPECT_TRUE(ctx.WithRankAtMost(unknown_rank, 2, &second).ok());
  EXPECT_EQ("?", ctx.DebugString(second));
  EXPECT_TRUE(SameHandle(first, second));

  // A rank-3 shape fails an upper bound of 2 and leaves the output unset.
  first = rank3;
  const Status status = ctx.WithRankAtMost(rank3, 2, &first);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Shape must be at most rank 2 but is rank 3");
  EXPECT_FALSE(IsSet(first));

  // Bounds of 3, 4 and 5 all pass and return the input handle itself.
  for (const int max_rank : {3, 4, 5}) {
    EXPECT_TRUE(ctx.WithRankAtMost(rank3, max_rank, &first).ok());
    EXPECT_TRUE(SameHandle(first, rank3));
  }

  // The inputs themselves are untouched.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
406
// WithRankAtLeast on unknown-rank and known-rank inputs.
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});

  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle first;
  ShapeHandle second;

  // An unknown-rank shape passes any lower bound and is returned unchanged.
  EXPECT_TRUE(ctx.WithRankAtLeast(unknown_rank, 1, &first).ok());
  EXPECT_EQ("?", ctx.DebugString(first));
  EXPECT_TRUE(SameHandle(unknown_rank, first));

  EXPECT_TRUE(ctx.WithRankAtLeast(unknown_rank, 2, &second).ok());
  EXPECT_EQ("?", ctx.DebugString(second));
  EXPECT_TRUE(SameHandle(first, second));

  // A rank-3 shape fails a lower bound of 4 and leaves the output unset.
  first = rank3;
  const Status status = ctx.WithRankAtLeast(rank3, 4, &first);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Shape must be at least rank 4 but is rank 3");
  EXPECT_FALSE(IsSet(first));

  // Bounds of 3, 2 and 0 all pass and return the input handle itself.
  for (const int min_rank : {3, 2, 0}) {
    EXPECT_TRUE(ctx.WithRankAtLeast(rank3, min_rank, &first).ok());
    EXPECT_TRUE(SameHandle(first, rank3));
  }

  // The inputs themselves are untouched.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
445
// WithValue on an unknown dimension and on a known dimension.
TEST_F(ShapeInferenceTest, WithValue) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(1, 2), {S({1, -1})}, {}, {},
                       {});

  const ShapeHandle in = ctx.input(0);
  const DimensionHandle known_1 = ctx.Dim(in, 0);
  const DimensionHandle unknown = ctx.Dim(in, 1);
  DimensionHandle first;
  DimensionHandle second;

  // Any value can be imposed on an unknown dimension. Each call yields a
  // handle carrying that value, distinct from the input and from other calls.
  EXPECT_TRUE(ctx.WithValue(unknown, 1, &first).ok());
  EXPECT_EQ(1, ctx.Value(first));

  EXPECT_TRUE(ctx.WithValue(unknown, 2, &second).ok());
  EXPECT_EQ(2, ctx.Value(second));
  EXPECT_FALSE(SameHandle(first, second));
  EXPECT_FALSE(SameHandle(first, unknown));

  EXPECT_TRUE(ctx.WithValue(unknown, 1, &second).ok());
  EXPECT_EQ(1, ctx.Value(second));
  EXPECT_FALSE(SameHandle(first, second));

  // A known dimension accepts only its own value. Mismatches fail and leave
  // the output unset; a match returns the input handle.
  first = known_1;
  Status status = ctx.WithValue(known_1, 0, &first);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(), "Dimension must be 0 but is 1");
  EXPECT_FALSE(IsSet(first));

  first = known_1;
  status = ctx.WithValue(known_1, 2, &first);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(), "Dimension must be 2 but is 1");
  EXPECT_FALSE(IsSet(first));

  EXPECT_TRUE(ctx.WithValue(known_1, 1, &first).ok());
  EXPECT_TRUE(SameHandle(known_1, first));

  // The inputs themselves are untouched.
  EXPECT_EQ("1", ctx.DebugString(known_1));
  EXPECT_EQ("?", ctx.DebugString(unknown));
}
488
// Merge on dimension handles, including the bookkeeping in MergedDims().
TEST_F(ShapeInferenceTest, MergeDim) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
                       {}, {}, {});

  const ShapeHandle in = ctx.input(0);
  const DimensionHandle d2 = ctx.Dim(in, 0);
  const DimensionHandle d_unknown = ctx.Dim(in, 1);
  const DimensionHandle d2_b = ctx.Dim(in, 2);
  const DimensionHandle d1 = ctx.Dim(in, 3);
  const DimensionHandle d_unknown_b = ctx.Dim(in, 4);
  DimensionHandle out;

  // Merging with an unknown dimension returns the other argument (the first,
  // when both are unknown), and each such merge is recorded in order.
  EXPECT_TRUE(ctx.Merge(d2, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(ctx.Merge(d_unknown, d2, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(ctx.Merge(d_unknown, d_unknown_b, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));

  auto merged = ctx.MergedDims();
  ASSERT_EQ(3, merged.size());
  const DimensionHandle expected_first[] = {d2, d_unknown, d_unknown};
  const DimensionHandle expected_second[] = {d_unknown, d2, d_unknown_b};
  for (int i = 0; i < 3; ++i) {
    EXPECT_TRUE(merged[i].first.SameHandle(expected_first[i]));
    EXPECT_TRUE(merged[i].second.SameHandle(expected_second[i]));
  }

  // A self-merge returns the argument and is not recorded.
  EXPECT_TRUE(ctx.Merge(d2, d2, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(ctx.Merge(d_unknown, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  merged = ctx.MergedDims();
  EXPECT_EQ(3, merged.size());

  // Equal known values return the first argument and are not recorded.
  EXPECT_TRUE(ctx.Merge(d2, d2_b, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(ctx.Merge(d2_b, d2, &out).ok());
  EXPECT_TRUE(SameHandle(d2_b, out));
  merged = ctx.MergedDims();
  EXPECT_EQ(3, merged.size());

  // Unequal known values fail, leave `out` unset, and are not recorded.
  Status status = ctx.Merge(d2, d1, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Dimensions must be equal, but are 2 and 1");
  EXPECT_FALSE(IsSet(out));

  status = ctx.Merge(d1, d2, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Dimensions must be equal, but are 1 and 2");
  EXPECT_FALSE(IsSet(out));

  merged = ctx.MergedDims();
  EXPECT_EQ(3, merged.size());
}
553
// Relax on dimension handles.
TEST_F(ShapeInferenceTest, RelaxDim) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(1, 2),
                       {S({2, InferenceContext::kUnknownDim, 2, 1,
                           InferenceContext::kUnknownDim})},
                       {}, {}, {});

  const ShapeHandle in = ctx.input(0);
  const DimensionHandle d2 = ctx.Dim(in, 0);
  const DimensionHandle d_unknown = ctx.Dim(in, 1);
  const DimensionHandle d2_b = ctx.Dim(in, 2);
  const DimensionHandle d1 = ctx.Dim(in, 3);
  const DimensionHandle d_unknown_b = ctx.Dim(in, 4);
  DimensionHandle out;

  // With an unknown on either side the result is unknown. It is the second
  // argument when that one is unknown; otherwise it is a handle distinct from
  // the unknown first argument.
  Relax(&ctx, d2, d_unknown, &out);
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_FALSE(SameHandle(d_unknown_b, out));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(out));

  Relax(&ctx, d_unknown, d2, &out);
  EXPECT_FALSE(SameHandle(d_unknown, out));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(out));

  Relax(&ctx, d_unknown, d_unknown_b, &out);
  EXPECT_FALSE(SameHandle(d_unknown, out));
  EXPECT_TRUE(SameHandle(d_unknown_b, out));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(out));

  // A dimension relaxed with itself comes back unchanged.
  Relax(&ctx, d2, d2, &out);
  EXPECT_TRUE(SameHandle(d2, out));
  Relax(&ctx, d_unknown, d_unknown, &out);
  EXPECT_TRUE(SameHandle(d_unknown, out));

  // Equal known values keep the first argument.
  Relax(&ctx, d2, d2_b, &out);
  EXPECT_TRUE(SameHandle(d2, out));
  Relax(&ctx, d2_b, d2, &out);
  EXPECT_TRUE(SameHandle(d2_b, out));

  // Unequal known values relax to an unknown.
  Relax(&ctx, d2, d1, &out);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(out));
  Relax(&ctx, d1, d2, &out);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(out));
}
600
// Relax on shape handles.
TEST_F(ShapeInferenceTest, RelaxShape) {
  NodeDef def;
  InferenceContext c(
      kVersion, def, MakeOpDef(7, 2),
      {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
       S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
      {}, {}, {});

  auto s_unknown = c.input(0);
  auto s_1_2 = c.input(1);
  auto s_u_2 = c.input(2);
  auto s_1_u = c.input(3);
  auto s_1_3 = c.input(4);
  auto s_unknown_b = c.input(5);
  auto s_1 = c.input(6);
  ShapeHandle out;

  // Relaxing with an unknown-rank shape on either side yields an unknown-rank
  // shape.
  Relax(&c, s_unknown, s_1_2, &out);
  // Check the result itself: it must be a new unknown, not the unknown-rank
  // input it was relaxed from.
  EXPECT_FALSE(SameHandle(s_unknown, out));
  EXPECT_EQ("?", c.DebugString(out));
  Relax(&c, s_u_2, s_unknown, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("?", c.DebugString(out));
  Relax(&c, s_unknown, s_unknown_b, &out);
  EXPECT_FALSE(SameHandle(s_unknown, out));
  EXPECT_TRUE(SameHandle(s_unknown_b, out));
  EXPECT_EQ("?", c.DebugString(out));

  // Relaxing with self returns self.
  Relax(&c, s_1_2, s_1_2, &out);
  EXPECT_TRUE(SameHandle(out, s_1_2));

  // Relaxing where one of the inputs has less information.
  out = ShapeHandle();
  Relax(&c, s_1_2, s_u_2, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("[?,2]", c.DebugString(out));
  out = ShapeHandle();
  Relax(&c, s_u_2, s_1_2, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("[?,2]", c.DebugString(out));

  // Relaxing where each input has one distinct unknown dimension.
  Relax(&c, s_u_2, s_1_u, &out);
  EXPECT_EQ("[?,?]", c.DebugString(out));
  EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1)));
  auto s_u1 = c.UnknownShapeOfRank(1);
  auto s_u2 = c.UnknownShapeOfRank(1);
  Relax(&c, s_u1, s_u2, &out);
  EXPECT_FALSE(SameHandle(s_u1, out));

  // Relaxing with mismatched values in a dimension returns a shape with that
  // dimension unknown. `out` is reset first so the checks prove it was
  // overwritten.
  out = s_unknown;
  Relax(&c, s_u_2, s_1_3, &out);
  EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_EQ("[?,?]", c.DebugString(out));
  out = s_unknown;
  Relax(&c, s_1_3, s_u_2, &out);
  EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_EQ("[?,?]", c.DebugString(out));

  // Relaxing with mismatched ranks returns a new unknown. `out` still holds
  // the rank-2 result from above, so the "?" check proves it was overwritten.
  Relax(&c, s_1, s_1_2, &out);
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(s_unknown, out));
}
670
// Merge on shape handles, including the bookkeeping in MergedShapes().
TEST_F(ShapeInferenceTest, MergeShape) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(7, 2),
                       {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
                        Unknown(), S({1})},
                       {}, {}, {});

  const ShapeHandle s_unknown = ctx.input(0);
  const ShapeHandle s_1_2 = ctx.input(1);
  const ShapeHandle s_u_2 = ctx.input(2);
  const ShapeHandle s_1_u = ctx.input(3);
  const ShapeHandle s_1_3 = ctx.input(4);
  const ShapeHandle s_unknown_b = ctx.input(5);
  const ShapeHandle s_1 = ctx.input(6);
  ShapeHandle out;

  // An unknown-rank argument contributes nothing: the other argument (the
  // first, when both are unknown) is returned. Each merge is recorded.
  EXPECT_TRUE(ctx.Merge(s_unknown, s_1_2, &out).ok());
  EXPECT_TRUE(SameHandle(s_1_2, out));
  EXPECT_TRUE(ctx.Merge(s_u_2, s_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(s_u_2, out));
  EXPECT_TRUE(ctx.Merge(s_unknown, s_unknown_b, &out).ok());
  EXPECT_TRUE(SameHandle(s_unknown, out));

  auto merged = ctx.MergedShapes();
  ASSERT_EQ(3, merged.size());
  {
    const ShapeHandle firsts[] = {s_unknown, s_u_2, s_unknown};
    const ShapeHandle seconds[] = {s_1_2, s_unknown, s_unknown_b};
    for (int i = 0; i < 3; ++i) {
      EXPECT_TRUE(merged[i].first.SameHandle(firsts[i]));
      EXPECT_TRUE(merged[i].second.SameHandle(seconds[i]));
    }
  }

  // A self-merge returns the argument and is not recorded.
  EXPECT_TRUE(ctx.Merge(s_1_2, s_1_2, &out).ok());
  EXPECT_TRUE(SameHandle(out, s_1_2));
  merged = ctx.MergedShapes();
  EXPECT_EQ(3, merged.size());

  // When one argument already holds the full answer, that argument is
  // returned whichever side it is on.
  out = ShapeHandle();
  EXPECT_TRUE(ctx.Merge(s_1_2, s_u_2, &out).ok());
  EXPECT_TRUE(SameHandle(s_1_2, out));
  out = ShapeHandle();
  EXPECT_TRUE(ctx.Merge(s_u_2, s_1_2, &out).ok());
  EXPECT_TRUE(SameHandle(s_1_2, out));

  merged = ctx.MergedShapes();
  ASSERT_EQ(5, merged.size());
  EXPECT_TRUE(merged[3].first.SameHandle(s_1_2));
  EXPECT_TRUE(merged[3].second.SameHandle(s_u_2));
  EXPECT_TRUE(merged[4].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged[4].second.SameHandle(s_1_2));

  // When neither argument suffices, a new shape is built whose dimensions are
  // the known dimension handles taken from each side.
  EXPECT_TRUE(ctx.Merge(s_u_2, s_1_u, &out).ok());
  EXPECT_FALSE(SameHandle(out, s_u_2));
  EXPECT_FALSE(SameHandle(out, s_1_u));
  EXPECT_EQ("[1,2]", ctx.DebugString(out));
  EXPECT_TRUE(SameHandle(ctx.Dim(s_1_u, 0), ctx.Dim(out, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(s_u_2, 1), ctx.Dim(out, 1)));

  merged = ctx.MergedShapes();
  ASSERT_EQ(7, merged.size());
  EXPECT_TRUE(merged[5].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged[5].second.SameHandle(s_1_u));
  EXPECT_TRUE(merged[6].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged[6].second.SameHandle(out));

  // Two rank-1 shapes with unknown dimensions merge to the first.
  const ShapeHandle s_u1 = ctx.UnknownShapeOfRank(1);
  const ShapeHandle s_u2 = ctx.UnknownShapeOfRank(1);
  TF_EXPECT_OK(ctx.Merge(s_u1, s_u2, &out));
  EXPECT_TRUE(SameHandle(s_u1, out));

  merged = ctx.MergedShapes();
  ASSERT_EQ(8, merged.size());
  EXPECT_TRUE(merged[7].first.SameHandle(s_u1));
  EXPECT_TRUE(merged[7].second.SameHandle(s_u2));

  // Incompatible merges fail, leave `out` unset, and are not recorded.
  out = s_unknown;
  Status status = ctx.Merge(s_u_2, s_1_3, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Dimension 1 in both shapes must be equal, but are 2 and 3");
  EXPECT_FALSE(IsSet(out));

  out = s_unknown;
  status = ctx.Merge(s_1_3, s_u_2, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Dimension 1 in both shapes must be equal, but are 3 and 2");
  EXPECT_FALSE(IsSet(out));

  out = s_unknown;
  status = ctx.Merge(s_1, s_1_2, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Shapes must be equal rank, but are 1 and 2");
  EXPECT_FALSE(IsSet(out));

  merged = ctx.MergedShapes();
  EXPECT_EQ(8, merged.size());
}
777
// MergePrefix: merges a prefix shape into the leading dimensions of a shape.
TEST_F(ShapeInferenceTest, MergePrefix) {
  NodeDef def;
  InferenceContext ctx(kVersion, def, MakeOpDef(4, 2),
                       {Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4})}, {},
                       {}, {});

  const ShapeHandle s_unknown = ctx.input(0);
  const ShapeHandle s_u_2 = ctx.input(1);
  const ShapeHandle s_1_u_3 = ctx.input(2);
  const ShapeHandle s_2_4 = ctx.input(3);

  ShapeHandle merged;
  ShapeHandle merged_prefix;

  // If either side has unknown rank, both inputs come back unchanged.
  EXPECT_TRUE(
      ctx.MergePrefix(s_unknown, s_u_2, &merged, &merged_prefix).ok());
  EXPECT_TRUE(SameHandle(merged, s_unknown));
  EXPECT_TRUE(SameHandle(merged_prefix, s_u_2));
  EXPECT_TRUE(
      ctx.MergePrefix(s_1_u_3, s_unknown, &merged, &merged_prefix).ok());
  EXPECT_TRUE(SameHandle(merged, s_1_u_3));
  EXPECT_TRUE(SameHandle(merged_prefix, s_unknown));

  // Both known: [?,2] is merged into the leading dims of [1,?,3]. The two
  // results share the merged dimension handles, taken from whichever input
  // knew each value.
  EXPECT_TRUE(
      ctx.MergePrefix(s_1_u_3, s_u_2, &merged, &merged_prefix).ok());
  EXPECT_FALSE(SameHandle(merged, s_1_u_3));
  EXPECT_EQ("[1,2]", ctx.DebugString(merged_prefix));
  EXPECT_EQ("[1,2,3]", ctx.DebugString(merged));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 0), ctx.Dim(merged, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged, 0), ctx.Dim(s_1_u_3, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 1), ctx.Dim(merged, 1)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 1), ctx.Dim(s_u_2, 1)));

  // Failures leave both outputs unset. First case: conflicting leading
  // dimensions.
  merged = s_unknown;
  merged_prefix = s_unknown;
  Status status = ctx.MergePrefix(s_1_u_3, s_2_4, &merged, &merged_prefix);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Dimensions must be equal, but are 1 and 2");
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));

  // Second case: a prefix longer than the shape.
  merged = s_unknown;
  merged_prefix = s_unknown;
  status = ctx.MergePrefix(s_2_4, s_1_u_3, &merged, &merged_prefix);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Shape must be at least rank 3 but is rank 2");
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));
}
834
// Exercises InferenceContext::Subshape: the start-only overload on unknown and
// known shapes, the [start, end) overload with every combination of
// non-negative and negative bounds, and the invalid-range error paths.
TEST_F(ShapeInferenceTest, Subshape) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2),
                     {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});

  // Unknown-rank input: start 0 hands back the very same handle, any other
  // start yields a different unknown shape.
  ShapeHandle unknown = c.input(1);
  ShapeHandle out;
  TF_EXPECT_OK(c.Subshape(unknown, 0, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Subshape(unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Subshape(unknown, 200, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  const int kFullRank = 5;
  auto in0 = c.input(0);
  TF_EXPECT_OK(c.Subshape(in0, 0, &out));
  EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out));
  EXPECT_TRUE(SameHandle(out, in0));
  EXPECT_EQ(kFullRank, c.Rank(out));

  // Each (start, end) pair, including bounds past the rank, is requested four
  // ways (non-negative and negative spellings of each bound). All four
  // describe the same range, so they must produce the same dimension handles.
  ShapeHandle views[4];
  for (int start = 0; start <= kFullRank + 1; ++start) {
    const int neg_start = start >= kFullRank ? kFullRank : (start - kFullRank);
    for (int end = start; end <= kFullRank + 1; ++end) {
      const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank);
      TF_ASSERT_OK(c.Subshape(in0, start, end, &views[0]));
      TF_ASSERT_OK(c.Subshape(in0, neg_start, end, &views[1]));
      TF_ASSERT_OK(c.Subshape(in0, start, neg_end, &views[2]));
      TF_ASSERT_OK(c.Subshape(in0, neg_start, neg_end, &views[3]));

      const int expected_rank =
          std::min(kFullRank, end) - std::min(kFullRank, start);
      for (int which = 0; which < 4; ++which) {
        out = views[which];
        ASSERT_EQ(expected_rank, c.Rank(out))
            << "start: " << start << " end: " << end << " arr_idx: " << which
            << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out);
        for (int d = 0; d < c.Rank(out); ++d) {
          EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d)))
              << "arr_idx: " << which;
        }
      }
    }
  }

  // Invalid ranges fail with INVALID_ARGUMENT and leave `out` unset.
  out = unknown;
  Status status = c.Subshape(in0, 6, -3, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(
      status.error_message(),
      "Subshape must have computed start <= end, but is 5 "
      "and 2 (computed from start 6 and end -3 over shape with rank 5)");
  EXPECT_FALSE(IsSet(out));

  out = unknown;
  status = c.Subshape(in0, -50, 100, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Subshape start out of bounds: -50, for shape with rank 5");
  EXPECT_FALSE(IsSet(out));

  out = unknown;
  status = c.Subshape(in0, 0, -50, &out);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Subshape end out of bounds: -50, for shape with rank 5");
  EXPECT_FALSE(IsSet(out));
}
910
// Concatenate with an unknown-rank operand yields a new unknown shape.
// Concatenating two known shapes reuses every dimension handle, in order.
TEST_F(ShapeInferenceTest, Concatenate) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                     {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});

  auto in0 = c.input(0);
  auto in1 = c.input(1);
  ShapeHandle unknown = c.input(2);
  ShapeHandle out;

  TF_EXPECT_OK(c.Concatenate(unknown, unknown, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Concatenate(unknown, in0, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  TF_EXPECT_OK(c.Concatenate(in0, in1, &out));
  EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out));
  // The leading Rank(in0) dims of `out` come from in0, the rest from in1.
  const int rank0 = c.Rank(in0);
  for (int i = 0; i < rank0; ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(in0, i), c.Dim(out, i)));
  }
  for (int i = 0; i < c.Rank(in1); ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(in1, i), c.Dim(out, rank0 + i)));
  }
}
937
// ReplaceDim swaps one dimension, addressed by a positive or negative index.
// Unknown-rank inputs stay unknown, and out-of-range indices fail and clear
// the output handle.
TEST_F(ShapeInferenceTest, ReplaceDim) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
                     {}, {}, {});

  auto in = c.input(0);
  auto unknown = c.input(1);
  ShapeHandle replaced;

  // Non-negative indices.
  TF_EXPECT_OK(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced));
  EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced));
  EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // Negative indices count from the end.
  TF_EXPECT_OK(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // Out-of-range indices: the call fails and `replaced` is cleared.
  EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
  replaced = in;
  EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
}
969
// MakeShape assembles a shape from dimension handles (reusing them as-is) or
// from a mix of constants and handles. Every call yields a new shape handle.
TEST_F(ShapeInferenceTest, MakeShape) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
                     {}, {});

  auto in0 = c.input(0);
  const int rank = c.Rank(in0);

  // Collect in0's dimension handles in reverse order.
  std::vector<DimensionHandle> dims;
  dims.reserve(rank);
  for (int src = rank - 1; src >= 0; --src) {
    dims.push_back(c.Dim(in0, src));
  }

  auto reversed = c.MakeShape(dims);
  EXPECT_EQ("[5,?,3,2,1]", c.DebugString(reversed));
  EXPECT_TRUE(SameHandle(c.Dim(reversed, 0), c.Dim(in0, rank - 1)));

  // Building again from the same handles gives a distinct shape that still
  // shares the dimension handles.
  auto reversed_again = c.MakeShape(dims);
  EXPECT_FALSE(SameHandle(reversed, reversed_again));
  EXPECT_TRUE(SameHandle(c.Dim(reversed_again, 0), c.Dim(in0, rank - 1)));

  // Constants and handles can be mixed.
  auto mixed = c.MakeShape({1, 2, dims[2]});
  EXPECT_FALSE(SameHandle(reversed, mixed));
  EXPECT_EQ("[1,2,3]", c.DebugString(mixed));
}
995
// Every UnknownShape() call must produce a distinct unknown-rank handle.
TEST_F(ShapeInferenceTest, UnknownShape) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto first = c.UnknownShape();
  auto second = c.UnknownShape();
  EXPECT_EQ("?", c.DebugString(first));
  EXPECT_EQ("?", c.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
}
1007
// Serializing a fully known shape must keep the rank and every dimension
// size, and must not mark the rank as unknown.
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s = c.MakeShape({1, 2, 3});
  TensorShapeProto proto;
  c.ShapeHandleToProto(s, &proto);

  EXPECT_FALSE(proto.unknown_rank());
  // ASSERT (not EXPECT) so the indexed dim() accesses below stay in range.
  ASSERT_EQ(3, proto.dim_size());
  EXPECT_EQ(1, proto.dim(0).size());
  EXPECT_EQ(2, proto.dim(1).size());
  EXPECT_EQ(3, proto.dim(2).size());
}
1021
// An unknown-rank shape serializes with unknown_rank set and no dims.
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  TensorShapeProto proto;
  c.ShapeHandleToProto(c.UnknownShape(), &proto);

  EXPECT_TRUE(proto.unknown_rank());
  EXPECT_EQ(0, proto.dim_size());
}
1034
// Scalar() produces a rank-0 shape ("[]") on every call.
TEST_F(ShapeInferenceTest, Scalar) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  for (int call = 0; call < 2; ++call) {
    EXPECT_EQ("[]", c.DebugString(c.Scalar())) << "call " << call;
  }
}
1045
// Vector() builds a rank-1 shape from a constant size (including
// kUnknownDim) or from a dimension handle, which is reused as-is.
TEST_F(ShapeInferenceTest, Vector) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  // From constant sizes.
  EXPECT_EQ("[1]", c.DebugString(c.Vector(1)));
  EXPECT_EQ("[?]", c.DebugString(c.Vector(InferenceContext::kUnknownDim)));

  // From a handle: the handle becomes the vector's only dimension.
  auto unknown_dim = c.UnknownDim();
  auto from_handle = c.Vector(unknown_dim);
  EXPECT_EQ("[?]", c.DebugString(from_handle));
  EXPECT_TRUE(SameHandle(unknown_dim, c.Dim(from_handle, 0)));
}
1061
// Matrix() accepts any mix of constant sizes, kUnknownDim, and dimension
// handles. Handles passed in are reused as-is in the resulting shape.
TEST_F(ShapeInferenceTest, Matrix) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s0 = c.Matrix(1, 2);
  EXPECT_EQ("[1,2]", c.DebugString(s0));
  auto s1 = c.Matrix(0, InferenceContext::kUnknownDim);
  EXPECT_EQ("[0,?]", c.DebugString(s1));

  auto d1 = c.UnknownDim();
  auto d2 = c.UnknownDim();
  auto s2 = c.Matrix(d1, d2);
  EXPECT_EQ("[?,?]", c.DebugString(s2));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0)));
  EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1)));

  // Mixed handle/constant: the handle must be reused in the shape just built
  // (s3). Checking s2 here would not exercise s3 at all.
  auto s3 = c.Matrix(d1, 100);
  EXPECT_EQ("[?,100]", c.DebugString(s3));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s3, 0)));
}
1083
// MakeShapeFromShapeTensor reads a shape from the value of a 1-D int32/int64
// input tensor. A -1 entry is an unknown dimension and a scalar -1 is an
// unknown shape. Other dtypes, ranks, and values below -1 are rejected.
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
  // Runs inference with `t` as the value of input 0 (whose static shape is
  // unknown). Returns the inferred shape's debug string on success, or the
  // error message on failure. On failure it also checks that `out` was left
  // unset.
  auto infer = [&](Tensor* t) {
    NodeDef def;
    InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
                       {});
    ShapeHandle out;
    Status s = c.MakeShapeFromShapeTensor(0, &out);
    if (!s.ok()) {
      EXPECT_FALSE(IsSet(out));
      return s.error_message();
    }
    return c.DebugString(out);
  };

  // No tensor value available.
  EXPECT_EQ("?", infer(nullptr));

  Tensor t;
  // Valid 1-D shape tensors of either integer dtype.
  t = ::tensorflow::test::AsTensor<int32>({1, 2, 3});
  EXPECT_EQ("[1,2,3]", infer(&t));
  t = ::tensorflow::test::AsTensor<int64>({3, 2, 1});
  EXPECT_EQ("[3,2,1]", infer(&t));
  t = ::tensorflow::test::AsTensor<int64>({3, -1, 1});
  EXPECT_EQ("[3,?,1]", infer(&t));
  t = ::tensorflow::test::AsTensor<int64>({});
  EXPECT_EQ("[]", infer(&t));

  // The scalar -1 yields "?".
  t = ::tensorflow::test::AsScalar<int32>(-1);
  EXPECT_EQ("?", infer(&t));

  // Wrong dtype.
  t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
  EXPECT_CONTAINS(infer(&t),
                  "Input tensor must be int32 or int64, but was float");

  // Wrong rank: a scalar other than -1, and a matrix.
  t = ::tensorflow::test::AsScalar<int32>(1);
  auto scalar_error = infer(&t);
  EXPECT_CONTAINS(
      scalar_error,
      "Input tensor must be rank 1, or if its rank 0 it must have value -1");
  t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
  auto matrix_error = infer(&t);
  EXPECT_CONTAINS(matrix_error, "Input tensor must be rank 1, but was rank 2");

  // Dimension values below -1, for int64 input...
  t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
  EXPECT_CONTAINS(infer(&t), "Invalid value in tensor used for shape: -2");
  // ...and for int32 input.
  t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
  EXPECT_CONTAINS(infer(&t), "Invalid value in tensor used for shape: -2");

  // The input's own static shape must be rank 1.
  {
    NodeDef def;
    InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
                       {}, {});
    ShapeHandle out;
    EXPECT_EQ("Shape must be rank 1 but is rank 2",
              c.MakeShapeFromShapeTensor(0, &out).error_message());
  }
}
1150
// MakeShapeFromPartialTensorShape maps an unknown-rank PartialTensorShape to
// "?" and -1 dimensions to unknown dimensions.
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  const struct {
    PartialTensorShape shape;
    const char* expected;
  } kCases[] = {
      {PartialTensorShape(), "?"},  // Unknown rank.
      {PartialTensorShape({0}), "[0]"},
      {PartialTensorShape({0, -1, 1000}), "[0,?,1000]"},
  };
  for (const auto& tc : kCases) {
    ShapeHandle out;
    TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(tc.shape, &out));
    EXPECT_EQ(tc.expected, c.DebugString(out));
  }
}
1169
// MakeShapeFromTensorShape reproduces a fully known TensorShape, including
// the scalar shape and zero-sized dimensions.
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  const struct {
    TensorShape shape;
    const char* expected;
  } kCases[] = {
      {TensorShape(), "[]"},
      {TensorShape({0}), "[0]"},
      {TensorShape({0, 7, 1000}), "[0,7,1000]"},
  };
  for (const auto& tc : kCases) {
    ShapeHandle out;
    TF_ASSERT_OK(c.MakeShapeFromTensorShape(tc.shape, &out));
    EXPECT_EQ(tc.expected, c.DebugString(out));
  }
}
1183
// MakeShapeFromShapeProto handles unknown_rank, which must then carry no
// dims. It maps size -1 to an unknown dimension and rejects sizes below -1.
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
  TensorShapeProto proto;
  ShapeHandle out;

  // unknown_rank with no dims is a valid unknown shape.
  proto.set_unknown_rank(true);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("?", c.DebugString(out));
  // unknown_rank together with a dim is rejected.
  proto.add_dim()->set_size(0);
  EXPECT_CONTAINS(c.MakeShapeFromShapeProto(proto, &out).error_message(),
                  "An unknown shape must not have any dimensions set");
  EXPECT_FALSE(IsSet(out));

  // Clearing unknown_rank turns the same proto into the known shape [0].
  proto.set_unknown_rank(false);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("[0]", c.DebugString(out));
  // A size of -1 becomes an unknown dimension.
  proto.add_dim()->set_size(-1);
  proto.add_dim()->set_size(1000);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("[0,?,1000]", c.DebugString(out));

  // Sizes below -1 are invalid.
  proto.add_dim()->set_size(-2);
  EXPECT_CONTAINS(c.MakeShapeFromShapeProto(proto, &out).error_message(),
                  "Shape [0,?,1000,-2] has dimensions with values below -1 "
                  "(where -1 means unknown)");
  EXPECT_FALSE(IsSet(out));
}
1217
// MakeDim returns a dimension with the given value. Equal values still get
// distinct handles.
TEST_F(ShapeInferenceTest, MakeDim) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto one_a = c.MakeDim(1);
  auto one_b = c.MakeDim(1);
  auto two = c.MakeDim(2);
  EXPECT_EQ("1", c.DebugString(one_a));
  EXPECT_EQ("1", c.DebugString(one_b));
  EXPECT_FALSE(SameHandle(one_a, one_b));
  EXPECT_EQ("2", c.DebugString(two));
}
1231
// Every UnknownDim() call must produce a distinct unknown dimension handle.
TEST_F(ShapeInferenceTest, UnknownDim) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto first = c.UnknownDim();
  auto second = c.UnknownDim();
  EXPECT_EQ("?", c.DebugString(first));
  EXPECT_EQ("?", c.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
}
1243
// UnknownShapeOfRank(n) has rank n with every dimension unknown. Rank 0 is
// the scalar shape.
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  EXPECT_EQ("[?,?,?]", c.DebugString(c.UnknownShapeOfRank(3)));
  EXPECT_EQ("[]", c.DebugString(c.UnknownShapeOfRank(0)));
}
1255
// input_tensor(i) returns exactly the pointers given to the constructor, and
// nullptr for an input with no supplied tensor.
TEST_F(ShapeInferenceTest, InputTensors) {
  const Tensor t1 = tensorflow::test::AsTensor<float>({10});
  const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
                     {&t1, &t2}, {}, {});

  const Tensor* const expected[] = {&t1, &t2, nullptr};
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(expected[i], c.input_tensor(i)) << "input " << i;
  }
}
1267
// MakeDimForScalarInput turns a scalar int32/int64 input value into a
// dimension and rejects negative values.
TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
  Tensor t1 = tensorflow::test::AsScalar<int32>(20);
  Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2},
                     {}, {});

  DimensionHandle d;
  // Input 0 holds 20 and input 1 holds -1.
  auto check_inputs = [&]() {
    TF_EXPECT_OK(c.MakeDimForScalarInput(0, &d));
    EXPECT_EQ("20", c.DebugString(d));
    EXPECT_CONTAINS(c.MakeDimForScalarInput(1, &d).error_message(),
                    "Dimension size, given by scalar input 1, must be "
                    "non-negative but is -1");
  };

  // int32 scalars.
  check_inputs();

  // The context was given &t1/&t2, so overwriting the tensors in place
  // repeats the same checks with int64 scalars.
  t1 = tensorflow::test::AsScalar<int64>(20);
  t2 = tensorflow::test::AsScalar<int64>(-1);
  check_inputs();
}
1293
// GetAttr reads attribute values from the NodeDef the context was built with.
TEST_F(ShapeInferenceTest, GetAttr) {
  OpRegistrationData op_reg_data;
  op_reg_data.op_def = MakeOpDef(0, 2);
  NodeDef def;
  TF_CHECK_OK(NodeDefBuilder("dummy", &op_reg_data.op_def)
                  .Attr("foo", "bar")
                  .Finalize(&def));

  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {});

  string value;
  TF_EXPECT_OK(c.GetAttr("foo", &value));
  EXPECT_EQ("bar", value);
}
1309
// Divide behavior under test:
// - Dividing by 1 (constant or dimension) returns the dividend handle.
// - An unknown dividend gives a fresh unknown.
// - Non-positive divisors are rejected.
// - `evenly_divisible` decides whether a remainder is an error (true) or the
//   quotient is truncated (false).
TEST_F(ShapeInferenceTest, Divide) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
                     {}, {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_1 = c.Dim(s, 2);
  auto d_2 = c.Dim(s, 3);
  auto d_0 = c.Dim(s, 4);
  bool evenly_divisible = true;

  // Dividing unknown by non-1 gives new unknown.
  DimensionHandle out;
  EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Dividing anything by 1 returns the input.
  EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_unknown));
  EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_6));
  EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_unknown));
  EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_6));

  EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok());
  EXPECT_EQ("3", c.DebugString(out));
  EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
  EXPECT_EQ("3", c.DebugString(out));

  EXPECT_CONTAINS(c.Divide(d_6, 5, evenly_divisible, &out).error_message(),
                  "Dimension size must be evenly divisible by 5 but is 6");

  EXPECT_CONTAINS(c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");
  EXPECT_CONTAINS(c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");

  EXPECT_CONTAINS(c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is -1");

  // Repeat every error case above with evenly_divisible=false. A remainder
  // is now truncated instead of rejected, while a zero divisor (constant or
  // dimension) and a negative divisor are still errors.
  evenly_divisible = false;
  EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
  EXPECT_EQ("1", c.DebugString(out));

  EXPECT_CONTAINS(c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");
  EXPECT_CONTAINS(c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");

  EXPECT_CONTAINS(c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is -1");
}
1366
// Add behavior under test:
// - Adding zero (constant or dimension) returns the other operand's handle.
// - An unknown operand yields unknown.
// - Sums up to the int64 maximum work; one past it is an overflow error.
TEST_F(ShapeInferenceTest, Add) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);
  const int64 kMax = std::numeric_limits<int64>::max();
  DimensionHandle out;

  // Unknown plus a non-zero constant is a fresh unknown.
  TF_EXPECT_OK(c.Add(d_unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Adding zero, as a constant or as a dimension, is the identity.
  TF_EXPECT_OK(c.Add(d_unknown, 0, &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Add(d_6, 0, &out));
  EXPECT_TRUE(SameHandle(out, d_6));
  TF_EXPECT_OK(c.Add(d_unknown, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Add(d_6, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Known plus constant, up to the int64 limit.
  TF_EXPECT_OK(c.Add(d_6, 2, &out));
  EXPECT_EQ("8", c.DebugString(out));
  TF_EXPECT_OK(c.Add(d_6, kMax - 6, &out));
  EXPECT_EQ(kMax, c.Value(out));

  // Known plus dimension handle.
  TF_EXPECT_OK(c.Add(d_6, c.MakeDim(2), &out));
  EXPECT_EQ("8", c.DebugString(out));
  TF_EXPECT_OK(c.Add(d_6, c.MakeDim(kMax - 6), &out));
  EXPECT_EQ(kMax, c.Value(out));
  TF_EXPECT_OK(c.Add(d_6, c.UnknownDim(), &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Add(d_0, d_6, &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // One past the limit overflows.
  EXPECT_CONTAINS(
      c.Add(d_6, kMax - 5, &out).error_message(),
      "Dimension size overflow from adding 6 and 9223372036854775802");
}
1416
// Subtract behavior under test:
// - Removing zero (constant or dimension) returns the input handle.
// - An unknown operand yields unknown.
// - A negative result is an error.
TEST_F(ShapeInferenceTest, Subtract) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);
  auto d_5 = c.Dim(s, 3);
  DimensionHandle out;

  // Unknown minus a non-zero constant is a fresh unknown.
  TF_EXPECT_OK(c.Subtract(d_unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Subtracting zero, as a constant or as a dimension, is the identity.
  TF_EXPECT_OK(c.Subtract(d_unknown, 0ll, &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Subtract(d_6, 0ll, &out));
  EXPECT_TRUE(SameHandle(out, d_6));
  TF_EXPECT_OK(c.Subtract(d_unknown, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Subtract(d_6, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Known minus constant.
  TF_EXPECT_OK(c.Subtract(d_6, 2, &out));
  EXPECT_EQ("4", c.DebugString(out));
  TF_EXPECT_OK(c.Subtract(d_6, 6, &out));
  EXPECT_EQ("0", c.DebugString(out));

  // Known minus dimension handle.
  TF_EXPECT_OK(c.Subtract(d_6, c.MakeDim(2), &out));
  EXPECT_EQ("4", c.DebugString(out));
  TF_EXPECT_OK(c.Subtract(d_6, d_5, &out));
  EXPECT_EQ("1", c.DebugString(out));
  TF_EXPECT_OK(c.Subtract(d_6, c.UnknownDim(), &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Subtract(d_6, d_0, &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Results below zero are rejected.
  EXPECT_CONTAINS(c.Subtract(d_5, d_6, &out).error_message(),
                  "Negative dimension size caused by subtracting 6 from 5");
}
1465
// Multiply behavior under test:
// - Zero (constant or dimension) on either side gives 0.
// - One on either side returns the other operand's handle.
// - Any other unknown operand gives a fresh unknown.
// - Known values are multiplied.
TEST_F(ShapeInferenceTest, Multiply) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);
  auto d_1 = c.Dim(s, 3);

  // Multiplying non-zero to unknown gives new unknown. Assert that it is a
  // new handle, as the parallel Add/Subtract tests do.
  DimensionHandle out;
  EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Multiplying 0 to anything gives 0.
  EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));

  // Multiplying 1 to anything gives the original.
  // (unknown -> unknown)
  EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  // (known -> known)
  EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));

  // Test multiplication.
  EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok());
  EXPECT_EQ("36", c.DebugString(out));

  // Test multiplication using dimension as second value.
  EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
}
1518
// FullyDefined is true only when the rank and every dimension value are known.
TEST_F(ShapeInferenceTest, FullyDefined) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Missing the rank, or missing a dimension value.
  auto unknown_rank = c.UnknownShape();
  EXPECT_FALSE(c.FullyDefined(unknown_rank));
  auto partially_known = c.Matrix(c.MakeDim(1), c.UnknownDim());
  EXPECT_FALSE(c.FullyDefined(partially_known));

  // Everything known, including the rank-0 case.
  auto known_matrix = c.Matrix(c.MakeDim(1), c.MakeDim(2));
  EXPECT_TRUE(c.FullyDefined(known_matrix));
  auto scalar = c.Scalar();
  EXPECT_TRUE(c.FullyDefined(scalar));
}
1532
// Min behavior under test:
// - A known zero wins over an unknown.
// - Other combinations involving an unknown give unknown.
// - Known operands give the smaller value, reusing a dimension operand's
//   handle when it already holds the minimum.
TEST_F(ShapeInferenceTest, Min) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_1 = c.Dim(s, 0);
  auto d_2 = c.Dim(s, 1);
  auto d_unknown = c.Dim(s, 2);
  auto d_0 = c.Dim(s, 3);
  DimensionHandle out;

  // Zero versus unknown, on either side and in either form.
  TF_EXPECT_OK(c.Min(d_0, d_unknown, &out));
  EXPECT_TRUE(SameHandle(d_0, out));
  TF_EXPECT_OK(c.Min(d_unknown, d_0, &out));
  EXPECT_TRUE(SameHandle(d_0, out));
  TF_EXPECT_OK(c.Min(c.MakeDim(0ll), d_unknown, &out));
  EXPECT_EQ("0", c.DebugString(out));
  TF_EXPECT_OK(c.Min(d_unknown, 0ll, &out));
  EXPECT_EQ("0", c.DebugString(out));

  // Unknown versus another unknown or a non-zero value.
  TF_EXPECT_OK(c.Min(d_unknown, d_unknown, &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Min(d_unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Min(d_1, d_unknown, &out));
  EXPECT_EQ("?", c.DebugString(out));

  // Known dimension versus constant.
  TF_EXPECT_OK(c.Min(d_1, 1, &out));
  EXPECT_TRUE(SameHandle(d_1, out));
  TF_EXPECT_OK(c.Min(d_1, 3, &out));
  EXPECT_TRUE(SameHandle(d_1, out));
  TF_EXPECT_OK(c.Min(d_2, 1, &out));
  EXPECT_EQ("1", c.DebugString(out));

  // Known dimension versus known dimension.
  TF_EXPECT_OK(c.Min(d_1, d_1, &out));
  EXPECT_TRUE(SameHandle(d_1, out));
  TF_EXPECT_OK(c.Min(d_1, d_2, &out));
  EXPECT_TRUE(SameHandle(d_1, out));
  TF_EXPECT_OK(c.Min(d_2, d_1, &out));
  EXPECT_TRUE(SameHandle(d_1, out));
  TF_EXPECT_OK(c.Min(d_2, d_2, &out));
  EXPECT_TRUE(SameHandle(d_2, out));
}
1581
// Max behavior under test:
// - Any unknown operand yields unknown.
// - Known operands give the larger value, reusing a dimension operand's
//   handle when it already holds the maximum.
TEST_F(ShapeInferenceTest, Max) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_1 = c.Dim(s, 0);
  auto d_2 = c.Dim(s, 1);
  auto d_unknown = c.Dim(s, 2);
  DimensionHandle out;

  // Any unknown operand.
  TF_EXPECT_OK(c.Max(d_unknown, d_unknown, &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Max(d_unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Max(d_1, d_unknown, &out));
  EXPECT_EQ("?", c.DebugString(out));

  // Known dimension versus constant.
  TF_EXPECT_OK(c.Max(d_1, 1, &out));
  EXPECT_TRUE(SameHandle(d_1, out));
  TF_EXPECT_OK(c.Max(d_2, 1, &out));
  EXPECT_TRUE(SameHandle(d_2, out));
  TF_EXPECT_OK(c.Max(d_2, 3, &out));
  EXPECT_EQ("3", c.DebugString(out));

  // Known dimension versus known dimension.
  TF_EXPECT_OK(c.Max(d_1, d_1, &out));
  EXPECT_TRUE(SameHandle(d_1, out));
  TF_EXPECT_OK(c.Max(d_1, d_2, &out));
  EXPECT_TRUE(SameHandle(d_2, out));
  TF_EXPECT_OK(c.Max(d_2, d_1, &out));
  EXPECT_TRUE(SameHandle(d_2, out));
  TF_EXPECT_OK(c.Max(d_2, d_2, &out));
  EXPECT_TRUE(SameHandle(d_2, out));
}
1619
TestMergeHandles(bool input_not_output)1620 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
1621 NodeDef def;
1622 InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1623 {});
1624 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1625 ShapeHandle s;
1626 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1627 return s;
1628 };
1629 auto get_shapes_and_types_from_context = [&](int idx) {
1630 if (input_not_output) {
1631 return c.input_handle_shapes_and_types(idx);
1632 } else {
1633 return c.output_handle_shapes_and_types(idx);
1634 }
1635 };
1636 auto merge_shapes_and_types_to_context =
1637 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1638 if (input_not_output) {
1639 return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
1640 } else {
1641 return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
1642 }
1643 };
1644
1645 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1646 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1647
1648 // First merge will take the input completely.
1649 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1650 {c.UnknownShape(), DT_INVALID},
1651 {make_shape({4, 3, 2, 1}), DT_INT32}};
1652 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1653 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1654 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1655 ASSERT_EQ(3, v.size());
1656 for (int i = 0; i < v.size(); ++i) {
1657 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1658 EXPECT_EQ(t[i].dtype, v[i].dtype);
1659 }
1660
1661 // Merge that fails because wrong number of values passed.
1662 // Fails, and no changes made.
1663 ASSERT_FALSE(merge_shapes_and_types_to_context(
1664 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1665 v = *get_shapes_and_types_from_context(0);
1666 ASSERT_EQ(3, v.size());
1667 for (int i = 0; i < v.size(); ++i) {
1668 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1669 EXPECT_EQ(t[i].dtype, v[i].dtype);
1670 }
1671
1672 // Only difference is in a mismatched shape. That is ignored,
1673 // and there are no other changes, so nothing is done.
1674 //
1675 // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
1676 // return an error (separate error from 'refined' output)?
1677 auto t2 = t;
1678 t2[2].shape = make_shape({4, 3, 4, 1});
1679 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1680 v = *get_shapes_and_types_from_context(0);
1681 ASSERT_EQ(3, v.size());
1682 for (int i = 0; i < v.size(); ++i) {
1683 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1684 EXPECT_EQ(t[i].dtype, v[i].dtype);
1685 }
1686
1687 // Only difference is in a mismatched dtype, but that cannot be
1688 // updated unless original dtype is DT_INVALID.
1689 t2 = t;
1690 t2[2].dtype = DT_FLOAT;
1691 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1692 v = *get_shapes_and_types_from_context(0);
1693 ASSERT_EQ(3, v.size());
1694 for (int i = 0; i < v.size(); ++i) {
1695 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1696 EXPECT_EQ(t[i].dtype, v[i].dtype);
1697 }
1698
1699 // Difference is mergeable (new shape).
1700 t[1].shape = make_shape({1, 10});
1701 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1702 v = *get_shapes_and_types_from_context(0);
1703 ASSERT_EQ(3, v.size());
1704 for (int i = 0; i < v.size(); ++i) {
1705 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1706 EXPECT_EQ(t[i].dtype, v[i].dtype);
1707 }
1708
1709 // Difference is mergeable (new type).
1710 t[1].dtype = DT_DOUBLE;
1711 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1712 v = *get_shapes_and_types_from_context(0);
1713 ASSERT_EQ(3, v.size());
1714 for (int i = 0; i < v.size(); ++i) {
1715 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1716 EXPECT_EQ(t[i].dtype, v[i].dtype);
1717 }
1718
1719 // No difference.
1720 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
1721 }
1722
// Runs the shared TestMergeHandles checks with input_not_output = true, i.e.
// against the input-side handle shapes and types.
TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) {
  constexpr bool kInputNotOutput = true;
  TestMergeHandles(kInputNotOutput);
}
1726
// Runs the shared TestMergeHandles checks with input_not_output = false, i.e.
// against the output-side handle shapes and types.
TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
  constexpr bool kInputNotOutput = false;
  TestMergeHandles(kInputNotOutput);
}
1730
TestRelaxHandles(bool input_not_output)1731 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
1732 NodeDef def;
1733 InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1734 {});
1735 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1736 ShapeHandle s;
1737 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1738 return s;
1739 };
1740 auto get_shapes_and_types_from_context = [&](int idx) {
1741 if (input_not_output) {
1742 return c.input_handle_shapes_and_types(idx);
1743 } else {
1744 return c.output_handle_shapes_and_types(idx);
1745 }
1746 };
1747 auto relax_shapes_and_types_to_context =
1748 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1749 if (input_not_output) {
1750 return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types);
1751 } else {
1752 return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types);
1753 }
1754 };
1755
1756 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1757 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1758
1759 // First relax will take the input completely.
1760 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1761 {c.UnknownShape(), DT_INVALID},
1762 {make_shape({4, 3, 2, 1}), DT_INT32}};
1763 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1764 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1765 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1766 ASSERT_EQ(3, v.size());
1767 for (int i = 0; i < v.size(); ++i) {
1768 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1769 EXPECT_EQ(t[i].dtype, v[i].dtype);
1770 }
1771
1772 // Relax that fails because wrong number of values passed.
1773 // Fails, and no changes made.
1774 ASSERT_FALSE(relax_shapes_and_types_to_context(
1775 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1776 v = *get_shapes_and_types_from_context(0);
1777 ASSERT_EQ(3, v.size());
1778 for (int i = 0; i < v.size(); ++i) {
1779 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1780 EXPECT_EQ(t[i].dtype, v[i].dtype);
1781 }
1782
1783 // Only difference is in a mismatched shape. This should replace
1784 // the mismatched dimension with an UnknownDim.
1785 auto t2 = t;
1786 t2[2].shape = make_shape({4, 3, 4, 1});
1787 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2));
1788 v = *get_shapes_and_types_from_context(0);
1789 EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape));
1790 for (int i = 0; i < v.size(); ++i) {
1791 EXPECT_EQ(t[i].dtype, v[i].dtype);
1792 }
1793
1794 // Only difference is in a mismatched dtype, but that cannot be
1795 // updated unless original dtype is DT_INVALID.
1796 t2 = t;
1797 t2[2].dtype = DT_FLOAT;
1798 ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2));
1799 v = *get_shapes_and_types_from_context(0);
1800 ASSERT_EQ(3, v.size());
1801 for (int i = 0; i < v.size(); ++i) {
1802 EXPECT_EQ(t[i].dtype, v[i].dtype);
1803 }
1804
1805 // Difference is a new shape, which will result in a new UnknownShape.
1806 t[1].shape = make_shape({1, 10});
1807 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1808 v = *get_shapes_and_types_from_context(0);
1809 ASSERT_EQ(3, v.size());
1810 EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape));
1811 EXPECT_EQ("?", c.DebugString(v[1].shape));
1812 for (int i = 0; i < v.size(); ++i) {
1813 EXPECT_EQ(t[i].dtype, v[i].dtype);
1814 }
1815
1816 // Difference is relaxable (new type).
1817 t[1].dtype = DT_DOUBLE;
1818 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1819 v = *get_shapes_and_types_from_context(0);
1820 EXPECT_EQ(t[1].dtype, v[1].dtype);
1821 }
1822
// Runs the shared TestRelaxHandles checks against the input-side handle shapes
// and types (input_not_output = true).
TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) {
  constexpr bool kInputNotOutput = true;
  TestRelaxHandles(kInputNotOutput);
}
1826
// Runs the shared TestRelaxHandles checks against the output-side handle
// shapes and types (input_not_output = false).
TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) {
  constexpr bool kInputNotOutput = false;
  TestRelaxHandles(kInputNotOutput);
}
1830
1831 } // namespace shape_inference
1832 } // namespace tensorflow
1833