1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "tensorflow/core/framework/fake_input.h"
18 #include "tensorflow/core/framework/node_def_builder.h"
19 #include "tensorflow/core/framework/op_def_builder.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/test.h"
27
28 namespace tensorflow {
29 namespace shape_inference {
30 namespace {
31
MakeOpDefWithLists()32 OpDef MakeOpDefWithLists() {
33 OpRegistrationData op_reg_data;
34 OpDefBuilder b("dummy");
35 b.Input(strings::StrCat("input: N * float"));
36 b.Output(strings::StrCat("output: N * float"));
37 CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok());
38 return op_reg_data.op_def;
39 }
40
S(std::initializer_list<int64> dims)41 PartialTensorShape S(std::initializer_list<int64> dims) {
42 return PartialTensorShape(dims);
43 }
44
Unknown()45 PartialTensorShape Unknown() { return PartialTensorShape(); }
46
47 } // namespace
48
49 class ShapeInferenceTest : public ::testing::Test {
50 protected:
51 // These give access to private functions of DimensionHandle and ShapeHandle.
SameHandle(DimensionHandle a,DimensionHandle b)52 bool SameHandle(DimensionHandle a, DimensionHandle b) {
53 return a.SameHandle(b);
54 }
SameHandle(ShapeHandle a,ShapeHandle b)55 bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
IsSet(DimensionHandle d)56 bool IsSet(DimensionHandle d) { return d.IsSet(); }
IsSet(ShapeHandle s)57 bool IsSet(ShapeHandle s) { return s.IsSet(); }
Relax(InferenceContext * c,DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)58 void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1,
59 DimensionHandle* out) {
60 c->Relax(d0, d1, out);
61 }
Relax(InferenceContext * c,ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)62 void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1,
63 ShapeHandle* out) {
64 c->Relax(s0, s1, out);
65 }
66 void TestMergeHandles(bool input_not_output);
67 void TestRelaxHandles(bool input_not_output);
68
69 static const int kVersion = 0; // used for graph-def version.
70 };
71
TEST_F(ShapeInferenceTest, InputOutputByName) {
  // Set up a node whose list input "input" holds N = 3 float tensors.
  OpDef op_def = MakeOpDefWithLists();
  NodeDef def;
  // Verify the NodeDef was actually built; otherwise a builder failure would
  // only surface as confusing downstream expectation failures.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_def)
                   .Attr("N", 3)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
  InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
                     {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2))));

  // Getters: unknown names fail; the list input returns all three shapes.
  std::vector<ShapeHandle> shapes;
  EXPECT_FALSE(c.input("nonexistent", &shapes).ok());
  TF_EXPECT_OK(c.input("input", &shapes));
  // Guard the indexing below so a short result fails cleanly instead of
  // reading past the end of the vector.
  ASSERT_EQ(3, shapes.size());
  EXPECT_EQ("[1,5]", c.DebugString(shapes[0]));
  EXPECT_EQ("[2,5]", c.DebugString(shapes[1]));
  EXPECT_EQ("[1,3]", c.DebugString(shapes[2]));

  // Setters: unknown names fail; the list output accepts all three shapes.
  EXPECT_FALSE(c.set_output("nonexistent", shapes).ok());
  TF_EXPECT_OK(c.set_output("output", shapes));
  EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2))));
}
101
MakeOpDef(int num_inputs,int num_outputs)102 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
103 OpRegistrationData op_reg_data;
104 OpDefBuilder b("dummy");
105 for (int i = 0; i < num_inputs; ++i) {
106 b.Input(strings::StrCat("i", i, ": float"));
107 }
108 for (int i = 0; i < num_outputs; ++i) {
109 b.Output(strings::StrCat("o", i, ": float"));
110 }
111 CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
112 return op_reg_data.op_def;
113 }
114
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, &node_def, MakeOpDef(1, 1), {Unknown()}, {},
                       {}, {});

  // kUnknownDim and non-negative constants come back unchanged from Value().
  EXPECT_EQ(InferenceContext::kUnknownDim,
            ctx.Value(InferenceContext::kUnknownDim));
  EXPECT_EQ(1, ctx.Value(1));

#ifndef NDEBUG
  // Rejecting -7 relies on a DCHECK, so the death test is only compiled into
  // builds where DCHECKs are active.
  EXPECT_DEATH(ctx.Value(-7), "Dimension must be non\\-negative or equal to");
#endif
}
127
TEST_F(ShapeInferenceTest, Run) {
  NodeDef node_def;
  node_def.set_name("foo");
  node_def.set_op("foo_op");
  InferenceContext ctx(kVersion, &node_def, MakeOpDef(1, 2), {S({1})}, {}, {},
                       {});
  TF_ASSERT_OK(ctx.construction_status());

  // The rank-1 input satisfies "at most rank 6", so Run() succeeds.
  auto passing_fn = [](InferenceContext* ic) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), 6, &unused));
    ic->set_output(0, ic->input(0));
    ic->set_output(1, ic->input(0));
    return Status::OK();
  };
  TF_ASSERT_OK(ctx.Run(passing_fn));

  // The rank-1 input violates "at most rank 0"; Run() must decorate the
  // resulting error with the node name and op type.
  auto failing_fn = [](InferenceContext* ic) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), 0, &unused));
    ic->set_output(0, ic->input(0));
    ic->set_output(1, ic->input(0));
    return Status::OK();
  };
  const Status status = ctx.Run(failing_fn);
  EXPECT_TRUE(str_util::StrContains(
      status.ToString(),
      "Shape must be at most rank 0 but is rank 1 for 'foo' (op: 'foo_op')"))
      << status;
}
162
163 // Tests different context data added when Run returns error.
TEST_F(ShapeInferenceTest,AttachContext)164 TEST_F(ShapeInferenceTest, AttachContext) {
165 NodeDef def;
166 def.set_name("foo");
167 def.set_op("foo_op");
168 // Error when no constant tensors were requested.
169 {
170 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
171 {});
172 TF_ASSERT_OK(c.construction_status());
173 auto fn = [](InferenceContext* c) {
174 ShapeHandle h;
175 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
176 c->set_output(0, c->input(0));
177 return Status::OK();
178 };
179 EXPECT_EQ(
180 "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
181 "'foo' (op: 'foo_op') with input shapes: [1,2,3].",
182 c.Run(fn).ToString());
183 }
184
185 // Error when a constant tensor value was requested.
186 {
187 Tensor input_t =
188 ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
189 InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
190 {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
191 TF_ASSERT_OK(c.construction_status());
192 auto fn = [](InferenceContext* c) {
193 c->input_tensor(0); // get this one, but it's null - won't be in error.
194 c->input_tensor(1); // get this one, will now be in error.
195 ShapeHandle h;
196 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
197 c->set_output(0, c->input(0));
198 return Status::OK();
199 };
200 EXPECT_EQ(
201 "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
202 "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with "
203 "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.",
204 c.Run(fn).ToString());
205 }
206
207 // Error when a constant tensor value as shape was requested, but no partial
208 // shapes provided.
209 {
210 Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
211 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
212 {nullptr, &input_t}, {}, {});
213 TF_ASSERT_OK(c.construction_status());
214 auto fn = [](InferenceContext* c) {
215 ShapeHandle s;
216 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
217 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
218 ShapeHandle h;
219 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
220 c->set_output(0, c->input(0));
221 return Status::OK();
222 };
223 EXPECT_EQ(
224 "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
225 "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
226 "input tensors: input[1] = <1 2 3 4 5>.",
227 c.Run(fn).ToString());
228 }
229
230 // Error when a constant tensor value as shape was requested, and a partial
231 // shape was provided.
232 {
233 Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
234 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
235 {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
236 TF_ASSERT_OK(c.construction_status());
237 auto fn = [](InferenceContext* c) {
238 ShapeHandle s;
239 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
240 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
241 ShapeHandle h;
242 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
243 c->set_output(0, c->input(0));
244 return Status::OK();
245 };
246 EXPECT_EQ(
247 "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
248 "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
249 "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed "
250 "as partial shapes: input[0] = [10,?,5].",
251 c.Run(fn).ToString());
252 }
253 }
254
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, &node_def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
  EXPECT_EQ(3, ctx.num_inputs());
  EXPECT_EQ(2, ctx.num_outputs());

  // Unknown rank: lookups at 0, -1 and an out-of-range index all give an
  // unknown dimension.
  const ShapeHandle unknown_rank = ctx.input(0);
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_FALSE(ctx.RankKnown(unknown_rank));
  EXPECT_EQ(InferenceContext::kUnknownRank, ctx.Rank(unknown_rank));
  for (int idx : {0, -1, 1000}) {
    EXPECT_EQ("?", ctx.DebugString(ctx.Dim(unknown_rank, idx)));
  }

  // Rank-3 shape [1,?,3]: index i and its negative alias (i - 3) resolve to
  // the same dimension handle.
  const ShapeHandle rank3 = ctx.input(1);
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
  EXPECT_TRUE(ctx.RankKnown(rank3));
  EXPECT_EQ(3, ctx.Rank(rank3));
  auto expect_dim = [&](int idx, int64 value, const string& text) {
    const DimensionHandle dim = ctx.Dim(rank3, idx);
    EXPECT_EQ(value, ctx.Value(dim));
    EXPECT_EQ(value != InferenceContext::kUnknownDim, ctx.ValueKnown(dim));
    EXPECT_TRUE(SameHandle(dim, ctx.Dim(rank3, idx - 3)));
    EXPECT_EQ(text, ctx.DebugString(dim));
  };
  expect_dim(0, 1, "1");
  expect_dim(1, InferenceContext::kUnknownDim, "?");
  expect_dim(2, 3, "3");

  // Scalar: known rank 0.
  const ShapeHandle scalar = ctx.input(2);
  EXPECT_EQ("[]", ctx.DebugString(scalar));
  EXPECT_TRUE(ctx.RankKnown(scalar));
  EXPECT_EQ(0, ctx.Rank(scalar));
}
295
TEST_F(ShapeInferenceTest, NumElements) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, &node_def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {},
                       {});
  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle partly_known = ctx.input(1);
  const ShapeHandle fully_known = ctx.input(2);

  // An unknown rank or any unknown dimension makes the element count unknown.
  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(unknown_rank)));
  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(partly_known)));

  // That unknown count is a different handle from the input's unknown dim.
  EXPECT_FALSE(
      SameHandle(ctx.Dim(partly_known, 1), ctx.NumElements(partly_known)));

  // Fully known shape: 5 * 4 * 3 * 2 elements.
  EXPECT_EQ("120", ctx.DebugString(ctx.NumElements(fully_known)));
}
309
TEST_F(ShapeInferenceTest, WithRank) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
                     {Unknown(), S({1, -1, 3})}, {}, {}, {});

  auto in0 = c.input(0);
  auto in1 = c.input(1);
  ShapeHandle s1;
  ShapeHandle s2;

  // WithRank on a shape with unknown rank always succeeds. Each call yields a
  // new shape of the requested rank whose unknown dims are distinct handles.
  // TF_EXPECT_OK (rather than EXPECT_TRUE(...ok())) prints the Status on
  // failure.
  TF_EXPECT_OK(c.WithRank(in0, 1, &s1));
  EXPECT_EQ("[?]", c.DebugString(s1));

  TF_EXPECT_OK(c.WithRank(in0, 2, &s2));
  EXPECT_EQ("[?,?]", c.DebugString(s2));
  EXPECT_FALSE(SameHandle(s1, s2));
  EXPECT_FALSE(SameHandle(c.Dim(s2, 0), c.Dim(s2, 1)));

  TF_EXPECT_OK(c.WithRank(in0, 1, &s2));
  EXPECT_EQ("[?]", c.DebugString(s2));
  EXPECT_FALSE(SameHandle(s1, s2));

  TF_EXPECT_OK(c.WithRank(in0, 0, &s1));
  EXPECT_EQ("[]", c.DebugString(s1));

  // WithRank on a shape with known rank: a mismatch fails and clears the
  // output; a match returns the input handle itself.
  s1 = in1;
  EXPECT_EQ("Invalid argument: Shape must be rank 2 but is rank 3",
            c.WithRank(in1, 2, &s1).ToString());
  EXPECT_FALSE(IsSet(s1));
  TF_EXPECT_OK(c.WithRank(in1, 3, &s1));
  EXPECT_TRUE(SameHandle(s1, in1));

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(in0));
  EXPECT_EQ("[1,?,3]", c.DebugString(in1));
}
348
TEST_F(ShapeInferenceTest, WithRankAtMost) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
                     {Unknown(), S({1, -1, 3})}, {}, {}, {});

  auto in0 = c.input(0);
  auto in1 = c.input(1);
  ShapeHandle s1;
  ShapeHandle s2;

  // WithRankAtMost on a shape with unknown rank always succeeds and returns
  // the input handle. TF_EXPECT_OK prints the Status on failure.
  TF_EXPECT_OK(c.WithRankAtMost(in0, 1, &s1));
  EXPECT_EQ("?", c.DebugString(s1));
  EXPECT_TRUE(SameHandle(in0, s1));

  TF_EXPECT_OK(c.WithRankAtMost(in0, 2, &s2));
  EXPECT_EQ("?", c.DebugString(s2));
  EXPECT_TRUE(SameHandle(s1, s2));

  // WithRankAtMost on a shape with known rank: exceeding the limit fails and
  // clears the output; otherwise the input handle is returned.
  s1 = in1;
  EXPECT_TRUE(str_util::StrContains(
      c.WithRankAtMost(in1, 2, &s1).ToString(),
      "Invalid argument: Shape must be at most rank 2 but is rank 3"));

  EXPECT_FALSE(IsSet(s1));
  TF_EXPECT_OK(c.WithRankAtMost(in1, 3, &s1));
  EXPECT_TRUE(SameHandle(s1, in1));
  TF_EXPECT_OK(c.WithRankAtMost(in1, 4, &s1));
  EXPECT_TRUE(SameHandle(s1, in1));
  TF_EXPECT_OK(c.WithRankAtMost(in1, 5, &s1));
  EXPECT_TRUE(SameHandle(s1, in1));

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(in0));
  EXPECT_EQ("[1,?,3]", c.DebugString(in1));
}
386
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
                     {Unknown(), S({1, -1, 3})}, {}, {}, {});

  auto in0 = c.input(0);
  auto in1 = c.input(1);
  ShapeHandle s1;
  ShapeHandle s2;

  // WithRankAtLeast on a shape with unknown rank always succeeds and returns
  // the input handle. TF_EXPECT_OK prints the Status on failure.
  TF_EXPECT_OK(c.WithRankAtLeast(in0, 1, &s1));
  EXPECT_EQ("?", c.DebugString(s1));
  EXPECT_TRUE(SameHandle(in0, s1));

  TF_EXPECT_OK(c.WithRankAtLeast(in0, 2, &s2));
  EXPECT_EQ("?", c.DebugString(s2));
  EXPECT_TRUE(SameHandle(s1, s2));

  // WithRankAtLeast on a shape with known rank: falling short fails and
  // clears the output; otherwise the input handle is returned.
  s1 = in1;
  EXPECT_TRUE(str_util::StrContains(
      c.WithRankAtLeast(in1, 4, &s1).ToString(),
      "Invalid argument: Shape must be at least rank 4 but is rank 3"));

  EXPECT_FALSE(IsSet(s1));
  TF_EXPECT_OK(c.WithRankAtLeast(in1, 3, &s1));
  EXPECT_TRUE(SameHandle(s1, in1));
  TF_EXPECT_OK(c.WithRankAtLeast(in1, 2, &s1));
  EXPECT_TRUE(SameHandle(s1, in1));
  TF_EXPECT_OK(c.WithRankAtLeast(in1, 0, &s1));
  EXPECT_TRUE(SameHandle(s1, in1));

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(in0));
  EXPECT_EQ("[1,?,3]", c.DebugString(in1));
}
424
TEST_F(ShapeInferenceTest, WithValue) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});

  auto d0 = c.Dim(c.input(0), 0);
  auto d1 = c.Dim(c.input(0), 1);
  DimensionHandle out1;
  DimensionHandle out2;

  // WithValue on a dimension with unknown value always succeeds and returns a
  // new handle carrying the requested value. TF_EXPECT_OK prints the Status
  // on failure.
  TF_EXPECT_OK(c.WithValue(d1, 1, &out1));
  EXPECT_EQ(1, c.Value(out1));

  TF_EXPECT_OK(c.WithValue(d1, 2, &out2));
  EXPECT_EQ(2, c.Value(out2));
  EXPECT_FALSE(SameHandle(out1, out2));
  EXPECT_FALSE(SameHandle(out1, d1));

  TF_EXPECT_OK(c.WithValue(d1, 1, &out2));
  EXPECT_EQ(1, c.Value(out2));
  EXPECT_FALSE(SameHandle(out1, out2));

  // WithValue on a dimension with known size: a mismatch fails and clears the
  // output; a match returns the input handle.
  out1 = d0;

  EXPECT_TRUE(
      str_util::StrContains(c.WithValue(d0, 0, &out1).ToString(),
                            "Invalid argument: Dimension must be 0 but is 1"));
  EXPECT_FALSE(IsSet(out1));
  out1 = d0;
  EXPECT_TRUE(
      str_util::StrContains(c.WithValue(d0, 2, &out1).ToString(),
                            "Invalid argument: Dimension must be 2 but is 1"));

  EXPECT_FALSE(IsSet(out1));
  TF_EXPECT_OK(c.WithValue(d0, 1, &out1));
  EXPECT_TRUE(SameHandle(d0, out1));

  // Inputs are unchanged.
  EXPECT_EQ("1", c.DebugString(d0));
  EXPECT_EQ("?", c.DebugString(d1));
}
467
TEST_F(ShapeInferenceTest, MergeDim) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
                     {}, {}, {});

  auto d2 = c.Dim(c.input(0), 0);
  auto d_unknown = c.Dim(c.input(0), 1);
  auto d2_b = c.Dim(c.input(0), 2);
  auto d1 = c.Dim(c.input(0), 3);
  auto d_unknown_b = c.Dim(c.input(0), 4);
  DimensionHandle out;

  // Merging with an unknown dim returns the known one (the first, when both
  // are unknown), and each such merge is recorded in MergedDims().
  // TF_EXPECT_OK prints the Status on failure.
  TF_EXPECT_OK(c.Merge(d2, d_unknown, &out));
  EXPECT_TRUE(SameHandle(d2, out));
  TF_EXPECT_OK(c.Merge(d_unknown, d2, &out));
  EXPECT_TRUE(SameHandle(d2, out));
  TF_EXPECT_OK(c.Merge(d_unknown, d_unknown_b, &out));
  EXPECT_TRUE(SameHandle(d_unknown, out));

  auto merged_dims = c.MergedDims();
  ASSERT_EQ(3, merged_dims.size());
  EXPECT_TRUE(merged_dims[0].first.SameHandle(d2));
  EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown));
  EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown));
  EXPECT_TRUE(merged_dims[1].second.SameHandle(d2));
  EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown));
  EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b));

  // Merging with self is a no-op, returns self and is not recorded.
  TF_EXPECT_OK(c.Merge(d2, d2, &out));
  EXPECT_TRUE(SameHandle(d2, out));
  TF_EXPECT_OK(c.Merge(d_unknown, d_unknown, &out));
  EXPECT_TRUE(SameHandle(d_unknown, out));

  merged_dims = c.MergedDims();
  EXPECT_EQ(3, merged_dims.size());

  // Merging equal known values returns the first one and is not recorded.
  TF_EXPECT_OK(c.Merge(d2, d2_b, &out));
  EXPECT_TRUE(SameHandle(d2, out));
  TF_EXPECT_OK(c.Merge(d2_b, d2, &out));
  EXPECT_TRUE(SameHandle(d2_b, out));

  merged_dims = c.MergedDims();
  EXPECT_EQ(3, merged_dims.size());

  // Merging unequal values is an error that clears the output and is not
  // recorded.
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(d2, d1, &out).ToString(),
      "Invalid argument: Dimensions must be equal, but are 2 and 1"));

  EXPECT_FALSE(IsSet(out));
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(d1, d2, &out).ToString(),
      "Invalid argument: Dimensions must be equal, but are 1 and 2"));

  EXPECT_FALSE(IsSet(out));

  merged_dims = c.MergedDims();
  EXPECT_EQ(3, merged_dims.size());
}
530
TEST_F(ShapeInferenceTest, RelaxDim) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, &node_def, MakeOpDef(1, 2),
                       {S({2, InferenceContext::kUnknownDim, 2, 1,
                           InferenceContext::kUnknownDim})},
                       {}, {}, {});

  const ShapeHandle in = ctx.input(0);
  const DimensionHandle two = ctx.Dim(in, 0);
  const DimensionHandle unknown = ctx.Dim(in, 1);
  const DimensionHandle two_b = ctx.Dim(in, 2);
  const DimensionHandle one = ctx.Dim(in, 3);
  const DimensionHandle unknown_b = ctx.Dim(in, 4);
  DimensionHandle relaxed;

  // With an unknown on either side the result is unknown: the second
  // argument's handle when it is unknown, otherwise a new unknown.
  Relax(&ctx, two, unknown, &relaxed);
  EXPECT_TRUE(SameHandle(unknown, relaxed));
  EXPECT_FALSE(SameHandle(unknown_b, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));

  Relax(&ctx, unknown, two, &relaxed);
  EXPECT_FALSE(SameHandle(unknown, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));

  Relax(&ctx, unknown, unknown_b, &relaxed);
  EXPECT_FALSE(SameHandle(unknown, relaxed));
  EXPECT_TRUE(SameHandle(unknown_b, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));

  // A dimension relaxed against itself comes back unchanged.
  Relax(&ctx, two, two, &relaxed);
  EXPECT_TRUE(SameHandle(two, relaxed));
  Relax(&ctx, unknown, unknown, &relaxed);
  EXPECT_TRUE(SameHandle(unknown, relaxed));

  // Distinct handles with equal values yield the first argument.
  Relax(&ctx, two, two_b, &relaxed);
  EXPECT_TRUE(SameHandle(two, relaxed));
  Relax(&ctx, two_b, two, &relaxed);
  EXPECT_TRUE(SameHandle(two_b, relaxed));

  // Differing known values relax to unknown.
  Relax(&ctx, two, one, &relaxed);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
  Relax(&ctx, one, two, &relaxed);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
}
577
TEST_F(ShapeInferenceTest, RelaxShape) {
  NodeDef def;
  InferenceContext c(
      kVersion, &def, MakeOpDef(7, 2),
      {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
       S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
      {}, {}, {});

  auto s_unknown = c.input(0);
  auto s_1_2 = c.input(1);
  auto s_u_2 = c.input(2);
  auto s_1_u = c.input(3);
  auto s_1_3 = c.input(4);
  auto s_unknown_b = c.input(5);
  auto s_1 = c.input(6);
  ShapeHandle out;

  // Relaxing any shape with unknown returns an unknown shape.
  Relax(&c, s_unknown, s_1_2, &out);
  // Assert on the output itself: it must be a new handle, distinct from both
  // inputs.
  EXPECT_FALSE(SameHandle(s_unknown, out));
  EXPECT_FALSE(SameHandle(s_1_2, out));
  EXPECT_EQ("?", c.DebugString(out));
  Relax(&c, s_u_2, s_unknown, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("?", c.DebugString(out));
  Relax(&c, s_unknown, s_unknown_b, &out);
  EXPECT_FALSE(SameHandle(s_unknown, out));
  EXPECT_TRUE(SameHandle(s_unknown_b, out));
  EXPECT_EQ("?", c.DebugString(out));

  // Relaxing with self returns self.
  Relax(&c, s_1_2, s_1_2, &out);
  EXPECT_TRUE(SameHandle(out, s_1_2));

  // Relaxing where one of the inputs has less information.
  out = ShapeHandle();
  Relax(&c, s_1_2, s_u_2, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("[?,2]", c.DebugString(out));
  out = ShapeHandle();
  Relax(&c, s_u_2, s_1_2, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("[?,2]", c.DebugString(out));

  // Relaxing where each input has one distinct unknown dimension.
  Relax(&c, s_u_2, s_1_u, &out);
  EXPECT_EQ("[?,?]", c.DebugString(out));
  EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1)));
  auto s_u1 = c.UnknownShapeOfRank(1);
  auto s_u2 = c.UnknownShapeOfRank(1);
  Relax(&c, s_u1, s_u2, &out);
  EXPECT_FALSE(SameHandle(s_u1, out));

  // Relaxing with mismatched values in a dimension returns a shape with that
  // dimension unknown.
  out = s_unknown;
  Relax(&c, s_u_2, s_1_3, &out);
  EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_EQ("[?,?]", c.DebugString(out));
  out = s_unknown;
  Relax(&c, s_1_3, s_u_2, &out);
  EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_EQ("[?,?]", c.DebugString(out));
  out = s_unknown;

  // Relaxing with mismatched ranks returns a new unknown.
  Relax(&c, s_1, s_1_2, &out);
  EXPECT_EQ("?", c.DebugString(out));
}
647
TEST_F(ShapeInferenceTest, MergeShape) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
                     {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
                      Unknown(), S({1})},
                     {}, {}, {});

  auto s_unknown = c.input(0);
  auto s_1_2 = c.input(1);
  auto s_u_2 = c.input(2);
  auto s_1_u = c.input(3);
  auto s_1_3 = c.input(4);
  auto s_unknown_b = c.input(5);
  auto s_1 = c.input(6);
  ShapeHandle out;

  // Merging any shape with unknown returns the other shape (the first, when
  // both are unknown); each merge is recorded in MergedShapes().
  // TF_EXPECT_OK prints the Status on failure.
  TF_EXPECT_OK(c.Merge(s_unknown, s_1_2, &out));
  EXPECT_TRUE(SameHandle(s_1_2, out));
  TF_EXPECT_OK(c.Merge(s_u_2, s_unknown, &out));
  EXPECT_TRUE(SameHandle(s_u_2, out));
  TF_EXPECT_OK(c.Merge(s_unknown, s_unknown_b, &out));
  EXPECT_TRUE(SameHandle(s_unknown, out));

  auto merged_shapes = c.MergedShapes();
  ASSERT_EQ(3, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown));
  EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2));
  EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown));
  EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown));
  EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b));

  // Merging with self returns self and is not recorded.
  TF_EXPECT_OK(c.Merge(s_1_2, s_1_2, &out));
  EXPECT_TRUE(SameHandle(out, s_1_2));

  merged_shapes = c.MergedShapes();
  EXPECT_EQ(3, merged_shapes.size());

  // Merging where one of the inputs is the right answer - return that input.
  out = ShapeHandle();
  TF_EXPECT_OK(c.Merge(s_1_2, s_u_2, &out));
  EXPECT_TRUE(SameHandle(s_1_2, out));
  out = ShapeHandle();
  TF_EXPECT_OK(c.Merge(s_u_2, s_1_2, &out));
  EXPECT_TRUE(SameHandle(s_1_2, out));

  merged_shapes = c.MergedShapes();
  ASSERT_EQ(5, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2));
  EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2));

  // Merging where neither input is the right answer: a new shape is built
  // from the known dims of each input.
  TF_EXPECT_OK(c.Merge(s_u_2, s_1_u, &out));
  EXPECT_FALSE(SameHandle(out, s_u_2));
  EXPECT_FALSE(SameHandle(out, s_1_u));
  EXPECT_EQ("[1,2]", c.DebugString(out));
  EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1)));

  merged_shapes = c.MergedShapes();
  ASSERT_EQ(7, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u));
  EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[6].second.SameHandle(out));

  auto s_u1 = c.UnknownShapeOfRank(1);
  auto s_u2 = c.UnknownShapeOfRank(1);
  TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out));
  EXPECT_TRUE(SameHandle(s_u1, out));

  merged_shapes = c.MergedShapes();
  ASSERT_EQ(8, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1));
  EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2));

  // Incompatible merges return errors, leave `out` unset and are not
  // recorded.
  out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(s_u_2, s_1_3, &out).ToString(),
      "Invalid argument: Dimension 1 in both shapes must be equal, but "
      "are 2 and 3"));

  EXPECT_FALSE(IsSet(out));
  out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(s_1_3, s_u_2, &out).ToString(),
      "Invalid argument: Dimension 1 in both shapes must be equal, but "
      "are 3 and 2"));

  EXPECT_FALSE(IsSet(out));
  out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(s_1, s_1_2, &out).ToString(),
      "Invalid argument: Shapes must be equal rank, but are 1 and 2"));

  EXPECT_FALSE(IsSet(out));

  merged_shapes = c.MergedShapes();
  EXPECT_EQ(8, merged_shapes.size());
}
753
TEST_F(ShapeInferenceTest, MergePrefix) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
                     {
                         Unknown(),
                         S({-1, 2}),
                         S({1, -1, 3}),
                         S({2, 4}),
                     },
                     {}, {}, {});

  auto s_unknown = c.input(0);
  auto s_u_2 = c.input(1);
  auto s_1_u_3 = c.input(2);
  auto s_2_4 = c.input(3);

  ShapeHandle s_out;
  ShapeHandle s_prefix_out;

  // Merging with unknown returns the inputs. TF_EXPECT_OK prints the Status
  // on failure.
  TF_EXPECT_OK(c.MergePrefix(s_unknown, s_u_2, &s_out, &s_prefix_out));
  EXPECT_TRUE(SameHandle(s_out, s_unknown));
  EXPECT_TRUE(SameHandle(s_prefix_out, s_u_2));
  TF_EXPECT_OK(c.MergePrefix(s_1_u_3, s_unknown, &s_out, &s_prefix_out));
  EXPECT_TRUE(SameHandle(s_out, s_1_u_3));
  EXPECT_TRUE(SameHandle(s_prefix_out, s_unknown));

  // A compatible prefix: both outputs share the merged leading dims.
  TF_EXPECT_OK(c.MergePrefix(s_1_u_3, s_u_2, &s_out, &s_prefix_out));
  EXPECT_FALSE(SameHandle(s_out, s_1_u_3));
  EXPECT_EQ("[1,2]", c.DebugString(s_prefix_out));
  EXPECT_EQ("[1,2,3]", c.DebugString(s_out));
  EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 0), c.Dim(s_out, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_out, 0), c.Dim(s_1_u_3, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_out, 1)));
  EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_u_2, 1)));

  // Incompatible merges return errors and leave both outputs unset.
  s_out = s_unknown;
  s_prefix_out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString(),
      "Invalid argument: Dimensions must be equal, but are 1 and 2"));

  EXPECT_FALSE(IsSet(s_out));
  EXPECT_FALSE(IsSet(s_prefix_out));

  s_out = s_unknown;
  s_prefix_out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString(),
      "Invalid argument: Shape must be at least rank 3 but is rank 2"));
  EXPECT_FALSE(IsSet(s_out));
  EXPECT_FALSE(IsSet(s_prefix_out));
}
808
TEST_F(ShapeInferenceTest, Subshape) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
                     {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});

  // A subshape of an unknown-rank shape is unknown; only start 0 returns the
  // input handle itself. TF_EXPECT_OK prints the Status on failure.
  ShapeHandle unknown = c.input(1);
  ShapeHandle out;
  TF_EXPECT_OK(c.Subshape(unknown, 0, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Subshape(unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Subshape(unknown, 200, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  const int kFullRank = 5;
  ShapeHandle out_arr[4];
  auto in0 = c.input(0);
  TF_EXPECT_OK(c.Subshape(in0, 0, &out));
  EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out));
  EXPECT_TRUE(SameHandle(out, in0));
  EXPECT_EQ(kFullRank, c.Rank(out));
  for (int start = 0; start <= kFullRank + 1; ++start) {
    for (int end = start; end <= kFullRank + 1; ++end) {
      // Tag every failure in this iteration with its (start, end) pair.
      SCOPED_TRACE(strings::StrCat("start: ", start, " end: ", end));
      // Get subshapes using different start and end values that give the same
      // range: negative values count from the back, and values at or past the
      // rank are used as-is (the result is clamped to the rank).
      const int neg_start =
          start >= kFullRank ? kFullRank : (start - kFullRank);
      const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank);
      TF_ASSERT_OK(c.Subshape(in0, start, end, &out_arr[0]));
      TF_ASSERT_OK(c.Subshape(in0, neg_start, end, &out_arr[1]));
      TF_ASSERT_OK(c.Subshape(in0, start, neg_end, &out_arr[2]));
      TF_ASSERT_OK(c.Subshape(in0, neg_start, neg_end, &out_arr[3]));

      // Verify all computed subshapes.
      for (int arr_idx = 0; arr_idx < 4; ++arr_idx) {
        out = out_arr[arr_idx];
        ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start),
                  c.Rank(out))
            << "start: " << start << " end: " << end << " arr_idx: " << arr_idx
            << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out);
        for (int d = 0; d < c.Rank(out); ++d) {
          EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d)))
              << "arr_idx: " << arr_idx;
        }
      }
    }
  }

  // Errors leave the output unset.
  out = unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.Subshape(in0, 6, -3, &out).ToString(),
      "Invalid argument: Subshape must have computed start <= end, but is 5 "
      "and 2 (computed from start 6 and end -3 over shape with rank 5)"));
  EXPECT_FALSE(IsSet(out));
  out = unknown;
  EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, -50, 100, &out).ToString(),
                                    "Invalid argument: Subshape start out of "
                                    "bounds: -50, for shape with rank 5"));

  EXPECT_FALSE(IsSet(out));
  out = unknown;
  EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, 0, -50, &out).ToString(),
                                    "Invalid argument: Subshape end out of "
                                    "bounds: -50, for shape with rank 5"));

  EXPECT_FALSE(IsSet(out));
}
880
TEST_F(ShapeInferenceTest, Concatenate) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
                     {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});

  const ShapeHandle first = c.input(0);
  const ShapeHandle second = c.input(1);
  const ShapeHandle unknown = c.input(2);
  ShapeHandle joined;

  // Concatenating with an unknown-rank shape yields a different unknown
  // handle.
  TF_EXPECT_OK(c.Concatenate(unknown, unknown, &joined));
  EXPECT_EQ("?", c.DebugString(joined));
  EXPECT_FALSE(SameHandle(joined, unknown));
  TF_EXPECT_OK(c.Concatenate(unknown, first, &joined));
  EXPECT_EQ("?", c.DebugString(joined));
  EXPECT_FALSE(SameHandle(joined, unknown));

  // With both ranks known, the result reuses the dimension handles of the
  // first input followed by those of the second.
  TF_EXPECT_OK(c.Concatenate(first, second, &joined));
  EXPECT_EQ("[1,?,3,4,5]", c.DebugString(joined));
  const int first_rank = c.Rank(first);
  for (int i = 0; i < first_rank; ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(first, i), c.Dim(joined, i)));
  }
  for (int i = 0; i < c.Rank(second); ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(second, i), c.Dim(joined, first_rank + i)));
  }
}
907
TEST_F(ShapeInferenceTest, ReplaceDim) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
                     {}, {}, {});

  const ShapeHandle in = c.input(0);
  const ShapeHandle unknown = c.input(1);
  const DimensionHandle dim_two = c.Dim(in, 1);
  const DimensionHandle dim_three = c.Dim(in, 2);
  ShapeHandle replaced;

  // In-range indices substitute the given dimension; an unknown-rank input
  // stays unknown.
  TF_EXPECT_OK(c.ReplaceDim(in, 0, dim_two, &replaced));
  EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(in, 2, dim_two, &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(in, 1, dim_three, &replaced));
  EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(unknown, 0, dim_two, &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // Negative indices count from the end.
  TF_EXPECT_OK(c.ReplaceDim(in, -1, dim_two, &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(unknown, -1, dim_two, &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // Out-of-range indices fail and leave the output handle unset.
  EXPECT_FALSE(c.ReplaceDim(in, 3, dim_two, &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
  replaced = in;
  EXPECT_FALSE(c.ReplaceDim(in, -4, dim_two, &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
}
939
TEST_F(ShapeInferenceTest, MakeShape) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
                     {}, {});

  // Gather the input's dimension handles, last dimension first.
  const ShapeHandle in0 = c.input(0);
  const int rank = c.Rank(in0);
  std::vector<DimensionHandle> reversed;
  reversed.reserve(rank);
  for (int i = rank - 1; i >= 0; --i) {
    reversed.push_back(c.Dim(in0, i));
  }

  // Each MakeShape call returns a distinct shape handle that reuses the
  // supplied dimension handles.
  const ShapeHandle s = c.MakeShape(reversed);
  EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s));
  EXPECT_TRUE(SameHandle(c.Dim(s, 0), c.Dim(in0, rank - 1)));

  const ShapeHandle s2 = c.MakeShape(reversed);
  EXPECT_FALSE(SameHandle(s, s2));
  EXPECT_TRUE(SameHandle(c.Dim(s2, 0), c.Dim(in0, rank - 1)));

  // Constants and dimension handles may be mixed.
  const ShapeHandle s3 = c.MakeShape({1, 2, reversed[2]});
  EXPECT_FALSE(SameHandle(s, s3));
  EXPECT_EQ("[1,2,3]", c.DebugString(s3));
}
965
TEST_F(ShapeInferenceTest, UnknownShape) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Each call returns its own unknown-rank shape handle.
  const ShapeHandle first = c.UnknownShape();
  const ShapeHandle second = c.UnknownShape();
  for (const ShapeHandle& shape : {first, second}) {
    EXPECT_EQ("?", c.DebugString(shape));
  }
  EXPECT_FALSE(SameHandle(first, second));
}
977
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // A fully known shape serializes with known rank and one entry per
  // dimension, in order.
  auto s = c.MakeShape({1, 2, 3});
  TensorShapeProto proto;
  c.ShapeHandleToProto(s, &proto);

  EXPECT_FALSE(proto.unknown_rank());
  // ASSERT (not EXPECT) so that a wrong dimension count stops the test before
  // the proto.dim(i) accesses below could go out of range.
  ASSERT_EQ(3, proto.dim_size());
  EXPECT_EQ(1, proto.dim(0).size());
  EXPECT_EQ(2, proto.dim(1).size());
  EXPECT_EQ(3, proto.dim(2).size());
}
991
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // An unknown-rank shape serializes with unknown_rank set and no dims.
  const ShapeHandle unknown = c.UnknownShape();
  TensorShapeProto serialized;
  c.ShapeHandleToProto(unknown, &serialized);

  EXPECT_TRUE(serialized.unknown_rank());
  EXPECT_EQ(0, serialized.dim_size());
}
1004
TEST_F(ShapeInferenceTest, Scalar) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Repeated calls each produce a rank-0 shape.
  for (int call = 0; call < 2; ++call) {
    EXPECT_EQ("[]", c.DebugString(c.Scalar()));
  }
}
1015
TEST_F(ShapeInferenceTest, Vector) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Constant lengths, including kUnknownDim.
  EXPECT_EQ("[1]", c.DebugString(c.Vector(1)));
  EXPECT_EQ("[?]", c.DebugString(c.Vector(InferenceContext::kUnknownDim)));

  // A dimension handle passed in becomes the vector's only dimension.
  const DimensionHandle length = c.UnknownDim();
  const ShapeHandle vec = c.Vector(length);
  EXPECT_EQ("[?]", c.DebugString(vec));
  EXPECT_TRUE(SameHandle(length, c.Dim(vec, 0)));
}
1031
TEST_F(ShapeInferenceTest, Matrix) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Constant sizes, including kUnknownDim.
  auto s0 = c.Matrix(1, 2);
  EXPECT_EQ("[1,2]", c.DebugString(s0));
  auto s1 = c.Matrix(0, InferenceContext::kUnknownDim);
  EXPECT_EQ("[0,?]", c.DebugString(s1));

  // Dimension handles passed in are reused as the matrix's dimensions.
  auto d1 = c.UnknownDim();
  auto d2 = c.UnknownDim();
  auto s2 = c.Matrix(d1, d2);
  EXPECT_EQ("[?,?]", c.DebugString(s2));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0)));
  EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1)));

  // A handle mixed with a constant is still reused; check it on s3 itself.
  auto s3 = c.Matrix(d1, 100);
  EXPECT_EQ("[?,100]", c.DebugString(s3));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s3, 0)));
}
1053
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
  // Runs MakeShapeFromShapeTensor on input 0 (declared with unknown shape)
  // with `t` as its constant value. On success returns the inferred shape's
  // debug string; on failure checks that the output handle was left unset
  // and returns the error message.
  auto infer = [&](Tensor* t) {
    NodeDef def;
    InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
                       {});
    ShapeHandle out;
    const Status status = c.MakeShapeFromShapeTensor(0, &out);
    if (!status.ok()) {
      EXPECT_FALSE(IsSet(out));
      return status.error_message();
    }
    return c.DebugString(out);
  };

  // No constant value available: the shape is unknown.
  EXPECT_EQ("?", infer(nullptr));

  Tensor t;

  // int32 and int64 vectors are accepted; -1 entries become unknown
  // dimensions and an empty vector gives a rank-0 shape.
  t = ::tensorflow::test::AsTensor<int32>({1, 2, 3});
  EXPECT_EQ("[1,2,3]", infer(&t));
  t = ::tensorflow::test::AsTensor<int64>({3, 2, 1});
  EXPECT_EQ("[3,2,1]", infer(&t));
  t = ::tensorflow::test::AsTensor<int64>({3, -1, 1});
  EXPECT_EQ("[3,?,1]", infer(&t));
  t = ::tensorflow::test::AsTensor<int64>({});
  EXPECT_EQ("[]", infer(&t));

  // A scalar -1 yields an unknown-rank shape.
  t = ::tensorflow::test::AsScalar<int32>(-1);
  EXPECT_EQ("?", infer(&t));

  // Non-integer element types are rejected.
  t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
  EXPECT_TRUE(str_util::StrContains(
      infer(&t), "Input tensor must be int32 or int64, but was float"));

  // Scalars other than -1, and tensors of rank above 1, are rejected.
  t = ::tensorflow::test::AsScalar<int32>(1);
  const string scalar_error = infer(&t);
  EXPECT_TRUE(str_util::StrContains(
      scalar_error,
      "Input tensor must be rank 1, or if its rank 0 it must have value -1"))
      << scalar_error;

  t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
  const string matrix_error = infer(&t);
  EXPECT_TRUE(str_util::StrContains(
      matrix_error, "Input tensor must be rank 1, but was rank 2"))
      << matrix_error;

  // Entries below -1 are invalid, for both int64 and int32 tensors.
  t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
  EXPECT_TRUE(str_util::StrContains(
      infer(&t), "Invalid value in tensor used for shape: -2"));
  t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
  EXPECT_TRUE(str_util::StrContains(
      infer(&t), "Invalid value in tensor used for shape: -2"));

  // The shape-tensor input's own shape must have rank 1.
  {
    NodeDef def;
    InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
                       {}, {});
    ShapeHandle out;
    EXPECT_EQ("Shape must be rank 1 but is rank 2",
              c.MakeShapeFromShapeTensor(0, &out).error_message());
  }
}
1125
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Unknown rank first, then known ranks with known and unknown dims.
  struct Case {
    PartialTensorShape shape;
    const char* expected;
  };
  const Case cases[] = {
      {PartialTensorShape(), "?"},
      {PartialTensorShape({0}), "[0]"},
      {PartialTensorShape({0, -1, 1000}), "[0,?,1000]"},
  };
  ShapeHandle out;
  for (const Case& tc : cases) {
    TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(tc.shape, &out));
    EXPECT_EQ(tc.expected, c.DebugString(out));
  }
}
1144
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Fully defined shapes of rank 0, 1 and 3, including zero-sized dims.
  struct Case {
    TensorShape shape;
    const char* expected;
  };
  const Case cases[] = {
      {TensorShape(), "[]"},
      {TensorShape({0}), "[0]"},
      {TensorShape({0, 7, 1000}), "[0,7,1000]"},
  };
  ShapeHandle out;
  for (const Case& tc : cases) {
    TF_ASSERT_OK(c.MakeShapeFromTensorShape(tc.shape, &out));
    EXPECT_EQ(tc.expected, c.DebugString(out));
  }
}
1158
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
  TensorShapeProto proto;
  ShapeHandle out;

  // unknown_rank with no dims gives an unknown shape.
  proto.set_unknown_rank(true);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("?", c.DebugString(out));

  // unknown_rank combined with dims is rejected and leaves `out` unset.
  proto.add_dim()->set_size(0);
  const Status mixed_status = c.MakeShapeFromShapeProto(proto, &out);
  EXPECT_TRUE(str_util::StrContains(
      mixed_status.error_message(),
      "An unknown shape must not have any dimensions set."));
  EXPECT_FALSE(IsSet(out));

  // Known rank; a size of -1 becomes an unknown dimension.
  proto.set_unknown_rank(false);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("[0]", c.DebugString(out));
  proto.add_dim()->set_size(-1);
  proto.add_dim()->set_size(1000);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("[0,?,1000]", c.DebugString(out));

  // Sizes below -1 are rejected and leave `out` unset.
  proto.add_dim()->set_size(-2);
  const Status invalid_status = c.MakeShapeFromShapeProto(proto, &out);
  EXPECT_TRUE(str_util::StrContains(
      invalid_status.error_message(),
      "Shape [0,?,1000,-2] has dimensions with values below -1 "
      "(where -1 means unknown)"));
  EXPECT_FALSE(IsSet(out));
}
1194
TEST_F(ShapeInferenceTest, MakeDim) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Equal values still get distinct handles.
  const DimensionHandle first_one = c.MakeDim(1);
  const DimensionHandle second_one = c.MakeDim(1);
  const DimensionHandle two = c.MakeDim(2);
  EXPECT_EQ("1", c.DebugString(first_one));
  EXPECT_EQ("1", c.DebugString(second_one));
  EXPECT_FALSE(SameHandle(first_one, second_one));
  EXPECT_EQ("2", c.DebugString(two));
}
1208
TEST_F(ShapeInferenceTest, UnknownDim) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Each call returns its own unknown-dimension handle.
  const DimensionHandle first = c.UnknownDim();
  const DimensionHandle second = c.UnknownDim();
  for (const DimensionHandle& dim : {first, second}) {
    EXPECT_EQ("?", c.DebugString(dim));
  }
  EXPECT_FALSE(SameHandle(first, second));
}
1220
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // Known rank with every dimension unknown; rank 0 is a scalar.
  EXPECT_EQ("[?,?,?]", c.DebugString(c.UnknownShapeOfRank(3)));
  EXPECT_EQ("[]", c.DebugString(c.UnknownShapeOfRank(0)));
}
1232
TEST_F(ShapeInferenceTest, InputTensors) {
  const Tensor t1 = tensorflow::test::AsTensor<float>({10});
  const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
                     {&t1, &t2}, {}, {});

  // Supplied tensors are returned by pointer. Input 2 has no entry in the
  // tensor list and reports nullptr.
  EXPECT_EQ(&t1, c.input_tensor(0));
  EXPECT_EQ(&t2, c.input_tensor(1));
  EXPECT_TRUE(c.input_tensor(2) == nullptr);
}
1244
TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
  Tensor t1 = tensorflow::test::AsScalar<int32>(20);
  Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
                     {&t1, &t2}, {}, {});

  // Input 0 (value 20) becomes a dimension; input 1 (value -1) is rejected.
  auto check_scalar_inputs = [&]() {
    DimensionHandle d;
    TF_EXPECT_OK(c.MakeDimForScalarInput(0, &d));
    EXPECT_EQ("20", c.DebugString(d));
    EXPECT_TRUE(
        str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
                              "Dimension size, given by scalar input 1, must be "
                              "non-negative but is -1"));
  };

  check_scalar_inputs();

  // The context was given &t1 and &t2, so reassigning them switches the
  // inputs to int64 without rebuilding the context.
  t1 = tensorflow::test::AsScalar<int64>(20);
  t2 = tensorflow::test::AsScalar<int64>(-1);
  check_scalar_inputs();
}
1272
TEST_F(ShapeInferenceTest, GetAttr) {
  OpRegistrationData op_reg_data;
  op_reg_data.op_def = MakeOpDef(0, 2);
  NodeDef def;
  TF_CHECK_OK(NodeDefBuilder("dummy", &op_reg_data.op_def)
                  .Attr("foo", "bar")
                  .Finalize(&def));

  // An attribute set on the NodeDef is readable through the context.
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {});
  string foo_value;
  TF_EXPECT_OK(c.GetAttr("foo", &foo_value));
  EXPECT_EQ("bar", foo_value);
}
1288
TEST_F(ShapeInferenceTest, Divide) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
                     {}, {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_1 = c.Dim(s, 2);
  auto d_2 = c.Dim(s, 3);
  auto d_0 = c.Dim(s, 4);
  bool evenly_divisible = true;

  // Dividing unknown by non-1 gives a different unknown handle.
  DimensionHandle out;
  EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Dividing anything by 1 (constant or dimension) returns the input handle.
  EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_unknown));
  EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_6));
  EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_unknown));
  EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_6));

  // Exact division by a constant or by a dimension.
  EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok());
  EXPECT_EQ("3", c.DebugString(out));
  EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
  EXPECT_EQ("3", c.DebugString(out));

  // Error cases with evenly_divisible=true: inexact division, and
  // non-positive divisors given as constants or dimensions.
  EXPECT_TRUE(str_util::StrContains(
      c.Divide(d_6, 5, evenly_divisible, &out).error_message(),
      "Dimension size must be evenly divisible by 5 but is 6"));

  EXPECT_TRUE(str_util::StrContains(
      c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
      "Divisor must be positive but is 0"));
  EXPECT_TRUE(str_util::StrContains(
      c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
      "Divisor must be positive but is 0"));

  EXPECT_TRUE(str_util::StrContains(
      c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
      "Divisor must be positive but is -1"));

  // Repeat every error case above with evenly_divisible=false: inexact
  // division now truncates (6 / 5 == 1), but non-positive divisors, whether
  // constants or dimensions, are still rejected.
  evenly_divisible = false;
  EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
  EXPECT_EQ("1", c.DebugString(out));

  EXPECT_TRUE(str_util::StrContains(
      c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
      "Divisor must be positive but is 0"));
  EXPECT_TRUE(str_util::StrContains(
      c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
      "Divisor must be positive but is 0"));

  EXPECT_TRUE(str_util::StrContains(
      c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
      "Divisor must be positive but is -1"));
}
1351
TEST_F(ShapeInferenceTest, Add) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
                     {});

  const ShapeHandle input = c.input(0);
  const DimensionHandle six = c.Dim(input, 0);
  const DimensionHandle unknown = c.Dim(input, 1);
  const DimensionHandle zero = c.Dim(input, 2);
  const int64 kMaxDim = std::numeric_limits<int64>::max();
  DimensionHandle sum;

  // A non-zero addend on an unknown dimension gives a different unknown.
  TF_EXPECT_OK(c.Add(unknown, 1, &sum));
  EXPECT_EQ("?", c.DebugString(sum));
  EXPECT_FALSE(SameHandle(sum, unknown));

  // Adding zero, as a constant or as a dimension, returns the first handle.
  TF_EXPECT_OK(c.Add(unknown, 0, &sum));
  EXPECT_TRUE(SameHandle(sum, unknown));
  TF_EXPECT_OK(c.Add(six, 0, &sum));
  EXPECT_TRUE(SameHandle(sum, six));
  TF_EXPECT_OK(c.Add(unknown, c.MakeDim(0ll), &sum));
  EXPECT_TRUE(SameHandle(sum, unknown));
  TF_EXPECT_OK(c.Add(six, c.MakeDim(0ll), &sum));
  EXPECT_TRUE(SameHandle(sum, six));

  // Known plus constant, up to exactly the int64 maximum.
  TF_EXPECT_OK(c.Add(six, 2, &sum));
  EXPECT_EQ("8", c.DebugString(sum));
  TF_EXPECT_OK(c.Add(six, kMaxDim - 6, &sum));
  EXPECT_EQ(kMaxDim, c.Value(sum));

  // Known plus dimension.
  TF_EXPECT_OK(c.Add(six, c.MakeDim(2), &sum));
  EXPECT_EQ("8", c.DebugString(sum));
  TF_EXPECT_OK(c.Add(six, c.MakeDim(kMaxDim - 6), &sum));
  EXPECT_EQ(kMaxDim, c.Value(sum));
  TF_EXPECT_OK(c.Add(six, c.UnknownDim(), &sum));
  EXPECT_EQ("?", c.DebugString(sum));
  // A zero first operand returns the second operand's handle.
  TF_EXPECT_OK(c.Add(zero, six, &sum));
  EXPECT_TRUE(SameHandle(sum, six));

  // Going past the int64 maximum is an error.
  EXPECT_TRUE(str_util::StrContains(
      c.Add(six, kMaxDim - 5, &sum).error_message(),
      "Dimension size overflow from adding 6 and 9223372036854775802"));
}
1401
TEST_F(ShapeInferenceTest, Subtract) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
                     {}, {});

  const ShapeHandle input = c.input(0);
  const DimensionHandle six = c.Dim(input, 0);
  const DimensionHandle unknown = c.Dim(input, 1);
  const DimensionHandle zero = c.Dim(input, 2);
  const DimensionHandle five = c.Dim(input, 3);
  DimensionHandle diff;

  // A non-zero subtrahend on an unknown dimension gives a different unknown.
  TF_EXPECT_OK(c.Subtract(unknown, 1, &diff));
  EXPECT_EQ("?", c.DebugString(diff));
  EXPECT_FALSE(SameHandle(diff, unknown));

  // Subtracting zero, as a constant or as a dimension, returns the first
  // handle.
  TF_EXPECT_OK(c.Subtract(unknown, 0ll, &diff));
  EXPECT_TRUE(SameHandle(diff, unknown));
  TF_EXPECT_OK(c.Subtract(six, 0ll, &diff));
  EXPECT_TRUE(SameHandle(diff, six));
  TF_EXPECT_OK(c.Subtract(unknown, c.MakeDim(0ll), &diff));
  EXPECT_TRUE(SameHandle(diff, unknown));
  TF_EXPECT_OK(c.Subtract(six, c.MakeDim(0ll), &diff));
  EXPECT_TRUE(SameHandle(diff, six));

  // Known minus constant, down to exactly zero.
  TF_EXPECT_OK(c.Subtract(six, 2, &diff));
  EXPECT_EQ("4", c.DebugString(diff));
  TF_EXPECT_OK(c.Subtract(six, 6, &diff));
  EXPECT_EQ("0", c.DebugString(diff));

  // Known minus dimension.
  TF_EXPECT_OK(c.Subtract(six, c.MakeDim(2), &diff));
  EXPECT_EQ("4", c.DebugString(diff));
  TF_EXPECT_OK(c.Subtract(six, five, &diff));
  EXPECT_EQ("1", c.DebugString(diff));
  TF_EXPECT_OK(c.Subtract(six, c.UnknownDim(), &diff));
  EXPECT_EQ("?", c.DebugString(diff));
  TF_EXPECT_OK(c.Subtract(six, zero, &diff));
  EXPECT_TRUE(SameHandle(diff, six));

  // A negative result is an error.
  EXPECT_TRUE(str_util::StrContains(
      c.Subtract(five, six, &diff).error_message(),
      "Negative dimension size caused by subtracting 6 from 5"));
}
1451
TEST_F(ShapeInferenceTest, Multiply) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
                     {}, {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);
  auto d_1 = c.Dim(s, 3);

  // Multiplying unknown by a value other than 0 or 1 gives a new unknown:
  // not the input handle (matching the Add/Subtract/Divide tests).
  DimensionHandle out;
  EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Multiplying 0 to anything gives 0, in either operand order.
  EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));

  // Multiplying 1 to anything returns the other operand's handle.
  // (unknown -> unknown)
  EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  // (known -> known)
  EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));

  // Known times constant.
  EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok());
  EXPECT_EQ("36", c.DebugString(out));

  // Known times dimension.
  EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
}
1504
TEST_F(ShapeInferenceTest, FullyDefined) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  // An unknown rank or any unknown dimension means "not fully defined".
  const ShapeHandle unknown_rank = c.UnknownShape();
  EXPECT_FALSE(c.FullyDefined(unknown_rank));
  EXPECT_FALSE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.UnknownDim())));

  // Known rank with all dimensions known (including rank 0) is fully defined.
  EXPECT_TRUE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.MakeDim(2))));
  const ShapeHandle scalar = c.Scalar();
  EXPECT_TRUE(c.FullyDefined(scalar));
}
1518
TEST_F(ShapeInferenceTest, Min) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
                     {}, {});

  const ShapeHandle input = c.input(0);
  const DimensionHandle one = c.Dim(input, 0);
  const DimensionHandle two = c.Dim(input, 1);
  const DimensionHandle unknown = c.Dim(input, 2);
  const DimensionHandle zero = c.Dim(input, 3);
  DimensionHandle smaller;

  // Zero beats unknown; an existing zero handle is returned as-is.
  TF_EXPECT_OK(c.Min(zero, unknown, &smaller));
  EXPECT_TRUE(SameHandle(zero, smaller));
  TF_EXPECT_OK(c.Min(unknown, zero, &smaller));
  EXPECT_TRUE(SameHandle(zero, smaller));
  TF_EXPECT_OK(c.Min(c.MakeDim(0ll), unknown, &smaller));
  EXPECT_EQ("0", c.DebugString(smaller));
  TF_EXPECT_OK(c.Min(unknown, 0ll, &smaller));
  EXPECT_EQ("0", c.DebugString(smaller));

  // Unknown against an unknown or non-zero value stays unknown.
  TF_EXPECT_OK(c.Min(unknown, unknown, &smaller));
  EXPECT_EQ("?", c.DebugString(smaller));
  TF_EXPECT_OK(c.Min(unknown, 1, &smaller));
  EXPECT_EQ("?", c.DebugString(smaller));
  TF_EXPECT_OK(c.Min(one, unknown, &smaller));
  EXPECT_EQ("?", c.DebugString(smaller));

  // Known against a constant: the first handle is kept when it is smaller or
  // equal, otherwise the constant's value is returned.
  TF_EXPECT_OK(c.Min(one, 1, &smaller));
  EXPECT_TRUE(SameHandle(one, smaller));
  TF_EXPECT_OK(c.Min(one, 3, &smaller));
  EXPECT_TRUE(SameHandle(one, smaller));
  TF_EXPECT_OK(c.Min(two, 1, &smaller));
  EXPECT_EQ("1", c.DebugString(smaller));

  // Two known dimensions: the smaller handle is returned.
  TF_EXPECT_OK(c.Min(one, one, &smaller));
  EXPECT_TRUE(SameHandle(one, smaller));
  TF_EXPECT_OK(c.Min(one, two, &smaller));
  EXPECT_TRUE(SameHandle(one, smaller));
  TF_EXPECT_OK(c.Min(two, one, &smaller));
  EXPECT_TRUE(SameHandle(one, smaller));
  TF_EXPECT_OK(c.Min(two, two, &smaller));
  EXPECT_TRUE(SameHandle(two, smaller));
}
1567
TEST_F(ShapeInferenceTest, Max) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
                     {});

  const ShapeHandle input = c.input(0);
  const DimensionHandle one = c.Dim(input, 0);
  const DimensionHandle two = c.Dim(input, 1);
  const DimensionHandle unknown = c.Dim(input, 2);
  DimensionHandle larger;

  // Any unknown operand makes the result unknown.
  TF_EXPECT_OK(c.Max(unknown, unknown, &larger));
  EXPECT_EQ("?", c.DebugString(larger));
  TF_EXPECT_OK(c.Max(unknown, 1, &larger));
  EXPECT_EQ("?", c.DebugString(larger));
  TF_EXPECT_OK(c.Max(one, unknown, &larger));
  EXPECT_EQ("?", c.DebugString(larger));

  // Known against a constant: the first handle is kept when it is larger or
  // equal, otherwise the constant's value is returned.
  TF_EXPECT_OK(c.Max(one, 1, &larger));
  EXPECT_TRUE(SameHandle(one, larger));
  TF_EXPECT_OK(c.Max(two, 1, &larger));
  EXPECT_TRUE(SameHandle(two, larger));
  TF_EXPECT_OK(c.Max(two, 3, &larger));
  EXPECT_EQ("3", c.DebugString(larger));

  // Two known dimensions: the larger handle is returned.
  TF_EXPECT_OK(c.Max(one, one, &larger));
  EXPECT_TRUE(SameHandle(one, larger));
  TF_EXPECT_OK(c.Max(one, two, &larger));
  EXPECT_TRUE(SameHandle(two, larger));
  TF_EXPECT_OK(c.Max(two, one, &larger));
  EXPECT_TRUE(SameHandle(two, larger));
  TF_EXPECT_OK(c.Max(two, two, &larger));
  EXPECT_TRUE(SameHandle(two, larger));
}
1605
TestMergeHandles(bool input_not_output)1606 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
1607 NodeDef def;
1608 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1609 {});
1610 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1611 ShapeHandle s;
1612 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1613 return s;
1614 };
1615 auto get_shapes_and_types_from_context = [&](int idx) {
1616 if (input_not_output) {
1617 return c.input_handle_shapes_and_types(idx);
1618 } else {
1619 return c.output_handle_shapes_and_types(idx);
1620 }
1621 };
1622 auto merge_shapes_and_types_to_context =
1623 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1624 if (input_not_output) {
1625 return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
1626 } else {
1627 return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
1628 }
1629 };
1630
1631 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1632 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1633
1634 // First merge will take the input completely.
1635 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1636 {c.UnknownShape(), DT_INVALID},
1637 {make_shape({4, 3, 2, 1}), DT_INT32}};
1638 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1639 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1640 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1641 ASSERT_EQ(3, v.size());
1642 for (int i = 0; i < v.size(); ++i) {
1643 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1644 EXPECT_EQ(t[i].dtype, v[i].dtype);
1645 }
1646
1647 // Merge that fails because wrong number of values passed.
1648 // Fails, and no changes made.
1649 ASSERT_FALSE(merge_shapes_and_types_to_context(
1650 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1651 v = *get_shapes_and_types_from_context(0);
1652 ASSERT_EQ(3, v.size());
1653 for (int i = 0; i < v.size(); ++i) {
1654 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1655 EXPECT_EQ(t[i].dtype, v[i].dtype);
1656 }
1657
1658 // Only difference is in a mismatched shape. That is ignored,
1659 // and there are no other changes, so nothing is done.
1660 //
1661 // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
1662 // return an error (separate error from 'refined' output)?
1663 auto t2 = t;
1664 t2[2].shape = make_shape({4, 3, 4, 1});
1665 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1666 v = *get_shapes_and_types_from_context(0);
1667 ASSERT_EQ(3, v.size());
1668 for (int i = 0; i < v.size(); ++i) {
1669 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1670 EXPECT_EQ(t[i].dtype, v[i].dtype);
1671 }
1672
1673 // Only difference is in a mismatched dtype, but that cannot be
1674 // updated unless original dtype is DT_INVALID.
1675 t2 = t;
1676 t2[2].dtype = DT_FLOAT;
1677 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1678 v = *get_shapes_and_types_from_context(0);
1679 ASSERT_EQ(3, v.size());
1680 for (int i = 0; i < v.size(); ++i) {
1681 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1682 EXPECT_EQ(t[i].dtype, v[i].dtype);
1683 }
1684
1685 // Difference is mergeable (new shape).
1686 t[1].shape = make_shape({1, 10});
1687 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1688 v = *get_shapes_and_types_from_context(0);
1689 ASSERT_EQ(3, v.size());
1690 for (int i = 0; i < v.size(); ++i) {
1691 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1692 EXPECT_EQ(t[i].dtype, v[i].dtype);
1693 }
1694
1695 // Difference is mergeable (new type).
1696 t[1].dtype = DT_DOUBLE;
1697 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1698 v = *get_shapes_and_types_from_context(0);
1699 ASSERT_EQ(3, v.size());
1700 for (int i = 0; i < v.size(); ++i) {
1701 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1702 EXPECT_EQ(t[i].dtype, v[i].dtype);
1703 }
1704
1705 // No difference.
1706 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
1707 }
1708
// Runs the shared merge-handle checks against the input-handle side of the
// InferenceContext.
TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) {
  const bool input_not_output = true;
  TestMergeHandles(input_not_output);
}
1712
// Runs the shared merge-handle checks against the output-handle side of the
// InferenceContext.
TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
  const bool input_not_output = false;
  TestMergeHandles(input_not_output);
}
1716
TestRelaxHandles(bool input_not_output)1717 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
1718 NodeDef def;
1719 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1720 {});
1721 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1722 ShapeHandle s;
1723 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1724 return s;
1725 };
1726 auto get_shapes_and_types_from_context = [&](int idx) {
1727 if (input_not_output) {
1728 return c.input_handle_shapes_and_types(idx);
1729 } else {
1730 return c.output_handle_shapes_and_types(idx);
1731 }
1732 };
1733 auto relax_shapes_and_types_to_context =
1734 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1735 if (input_not_output) {
1736 return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types);
1737 } else {
1738 return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types);
1739 }
1740 };
1741
1742 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1743 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1744
1745 // First relax will take the input completely.
1746 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1747 {c.UnknownShape(), DT_INVALID},
1748 {make_shape({4, 3, 2, 1}), DT_INT32}};
1749 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1750 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1751 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1752 ASSERT_EQ(3, v.size());
1753 for (int i = 0; i < v.size(); ++i) {
1754 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1755 EXPECT_EQ(t[i].dtype, v[i].dtype);
1756 }
1757
1758 // Relax that fails because wrong number of values passed.
1759 // Fails, and no changes made.
1760 ASSERT_FALSE(relax_shapes_and_types_to_context(
1761 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1762 v = *get_shapes_and_types_from_context(0);
1763 ASSERT_EQ(3, v.size());
1764 for (int i = 0; i < v.size(); ++i) {
1765 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1766 EXPECT_EQ(t[i].dtype, v[i].dtype);
1767 }
1768
1769 // Only difference is in a mismatched shape. This should replace
1770 // the mismatched dimension with an UnknownDim.
1771 auto t2 = t;
1772 t2[2].shape = make_shape({4, 3, 4, 1});
1773 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2));
1774 v = *get_shapes_and_types_from_context(0);
1775 EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape));
1776 for (int i = 0; i < v.size(); ++i) {
1777 EXPECT_EQ(t[i].dtype, v[i].dtype);
1778 }
1779
1780 // Only difference is in a mismatched dtype, but that cannot be
1781 // updated unless original dtype is DT_INVALID.
1782 t2 = t;
1783 t2[2].dtype = DT_FLOAT;
1784 ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2));
1785 v = *get_shapes_and_types_from_context(0);
1786 ASSERT_EQ(3, v.size());
1787 for (int i = 0; i < v.size(); ++i) {
1788 EXPECT_EQ(t[i].dtype, v[i].dtype);
1789 }
1790
1791 // Difference is a new shape, which will result in a new UnknownShape.
1792 t[1].shape = make_shape({1, 10});
1793 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1794 v = *get_shapes_and_types_from_context(0);
1795 ASSERT_EQ(3, v.size());
1796 EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape));
1797 EXPECT_EQ("?", c.DebugString(v[1].shape));
1798 for (int i = 0; i < v.size(); ++i) {
1799 EXPECT_EQ(t[i].dtype, v[i].dtype);
1800 }
1801
1802 // Difference is relaxable (new type).
1803 t[1].dtype = DT_DOUBLE;
1804 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1805 v = *get_shapes_and_types_from_context(0);
1806 EXPECT_EQ(t[1].dtype, v[1].dtype);
1807 }
1808
TEST_F(ShapeInferenceTest,RelaxInputHandleShapesAndTypes)1809 TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) {
1810 TestRelaxHandles(true /* input_not_output */);
1811 }
1812
TEST_F(ShapeInferenceTest,RelaxOutputHandleShapesAndTypes)1813 TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) {
1814 TestRelaxHandles(false /* input_not_output */);
1815 }
1816
1817 } // namespace shape_inference
1818 } // namespace tensorflow
1819