• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include "tensorflow/core/framework/fake_input.h"
18 #include "tensorflow/core/framework/node_def_builder.h"
19 #include "tensorflow/core/framework/op_def_builder.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28 
29 namespace tensorflow {
30 namespace shape_inference {
31 namespace {
32 
// Non-fatal gtest expectation that the string-like value X contains the
// substring Y. Each argument is evaluated exactly once (bound to a local), so
// the failure message can print both values without re-evaluating them.
// The do/while(false) wrapper deliberately has NO trailing semicolon: the
// caller supplies it, so the macro behaves as a single statement (e.g. inside
// an unbraced if/else) instead of expanding to `...;;`.
#define EXPECT_CONTAINS(X, Y)                                    \
  do {                                                           \
    auto XSTR = (X);                                             \
    auto YSTR = (Y);                                             \
    EXPECT_TRUE(absl::StrContains(XSTR, YSTR))                   \
        << "'" << XSTR << "' does not contain '" << YSTR << "'"; \
  } while (false)
40 
MakeOpDefWithLists()41 OpDef MakeOpDefWithLists() {
42   OpRegistrationData op_reg_data;
43   OpDefBuilder b("dummy");
44   b.Input(strings::StrCat("input: N * float"));
45   b.Output(strings::StrCat("output: N * float"));
46   CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok());
47   return op_reg_data.op_def;
48 }
49 
S(std::initializer_list<int64> dims)50 PartialTensorShape S(std::initializer_list<int64> dims) {
51   return PartialTensorShape(dims);
52 }
53 
Unknown()54 PartialTensorShape Unknown() { return PartialTensorShape(); }
55 
56 }  // namespace
57 
58 class ShapeInferenceTest : public ::testing::Test {
59  protected:
60   // These give access to private functions of DimensionHandle and ShapeHandle.
SameHandle(DimensionHandle a,DimensionHandle b)61   bool SameHandle(DimensionHandle a, DimensionHandle b) {
62     return a.SameHandle(b);
63   }
SameHandle(ShapeHandle a,ShapeHandle b)64   bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
IsSet(DimensionHandle d)65   bool IsSet(DimensionHandle d) { return d.IsSet(); }
IsSet(ShapeHandle s)66   bool IsSet(ShapeHandle s) { return s.IsSet(); }
Relax(InferenceContext * c,DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)67   void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1,
68              DimensionHandle* out) {
69     c->Relax(d0, d1, out);
70   }
Relax(InferenceContext * c,ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)71   void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1,
72              ShapeHandle* out) {
73     c->Relax(s0, s1, out);
74   }
75   void TestMergeHandles(bool input_not_output);
76   void TestRelaxHandles(bool input_not_output);
77 
78   static constexpr int kVersion = 0;  // used for graph-def version.
79 };
80 
// Exercises the by-name input()/set_output() accessors on an op whose input
// and output are both lists of N = 3 floats.
TEST_F(ShapeInferenceTest, InputOutputByName) {
  OpDef op_def = MakeOpDefWithLists();
  NodeDef def;
  // Check the builder's Status instead of discarding it, so a malformed
  // NodeDef fails here rather than as a confusing error further down.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_def)
                   .Attr("N", 3)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
  InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
                     {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2))));

  // Getters: an unknown name fails; "input" yields the whole list.
  std::vector<ShapeHandle> shapes;
  EXPECT_FALSE(c.input("nonexistent", &shapes).ok());
  TF_ASSERT_OK(c.input("input", &shapes));
  // Fatal check: the indexing below would be out of bounds otherwise.
  ASSERT_EQ(3, shapes.size());
  EXPECT_EQ("[1,5]", c.DebugString(shapes[0]));
  EXPECT_EQ("[2,5]", c.DebugString(shapes[1]));
  EXPECT_EQ("[1,3]", c.DebugString(shapes[2]));

  // Setters: an unknown name fails; "output" assigns the whole list.
  EXPECT_FALSE(c.set_output("nonexistent", shapes).ok());
  TF_EXPECT_OK(c.set_output("output", shapes));
  EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2))));
}
110 
MakeOpDef(int num_inputs,int num_outputs)111 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
112   OpRegistrationData op_reg_data;
113   OpDefBuilder b("dummy");
114   for (int i = 0; i < num_inputs; ++i) {
115     b.Input(strings::StrCat("i", i, ": float"));
116   }
117   for (int i = 0; i < num_outputs; ++i) {
118     b.Output(strings::StrCat("o", i, ": float"));
119   }
120   CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
121   return op_reg_data.op_def;
122 }
123 
// Value() on a plain constant returns it unchanged, including the
// kUnknownDim sentinel.
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 1), {Unknown()}, {},
                       {}, {});
  EXPECT_EQ(InferenceContext::kUnknownDim,
            ctx.Value(InferenceContext::kUnknownDim));
  EXPECT_EQ(1, ctx.Value(1));

#ifndef NDEBUG
  // The negative-value check is a DCHECK, so this death test is only
  // meaningful when DCHECKs are compiled in.
  EXPECT_DEATH(ctx.Value(-7), "Dimension must be non\\-negative or equal to");
#endif
}
136 
// Run() returns the shape function's status. On failure, the error text also
// names the node and its op.
TEST_F(ShapeInferenceTest, Run) {
  NodeDef node_def;
  node_def.set_name("foo");
  node_def.set_op("foo_op");
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1})}, {}, {},
                       {});
  TF_ASSERT_OK(ctx.construction_status());

  // The rank-1 input satisfies "rank at most 6", so this succeeds.
  const auto passing_fn = [](InferenceContext* ic) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), 6, &unused));
    ic->set_output(0, ic->input(0));
    ic->set_output(1, ic->input(0));
    return Status::OK();
  };
  TF_ASSERT_OK(ctx.Run(passing_fn));

  // The rank-1 input violates "rank at most 0", so Run() reports an error.
  const auto failing_fn = [](InferenceContext* ic) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), 0, &unused));
    ic->set_output(0, ic->input(0));
    ic->set_output(1, ic->input(0));
    return Status::OK();
  };
  const auto msg = ctx.Run(failing_fn).ToString();
  EXPECT_CONTAINS(msg, "Shape must be at most rank 0 but is rank 1");
  EXPECT_CONTAINS(msg, "node foo");
  EXPECT_CONTAINS(msg, "foo_op");
}
170 
171 // Tests different context data added when Run returns error.
TEST_F(ShapeInferenceTest,AttachContext)172 TEST_F(ShapeInferenceTest, AttachContext) {
173   NodeDef def;
174   def.set_name("foo");
175   def.set_op("foo_op");
176   // Error when no constant tensors were requested.
177   {
178     InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
179                        {});
180     TF_ASSERT_OK(c.construction_status());
181     auto fn = [](InferenceContext* c) {
182       ShapeHandle h;
183       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
184       c->set_output(0, c->input(0));
185       return Status::OK();
186     };
187     auto s = c.Run(fn).ToString();
188     EXPECT_CONTAINS(s, "Shape must be at most rank 0 but is rank 3");
189     EXPECT_CONTAINS(s, "node foo");
190     EXPECT_CONTAINS(s, "foo_op");
191     EXPECT_CONTAINS(s, "input shapes: [1,2,3]");
192   }
193 
194   // Error when a constant tensor value was requested.
195   {
196     Tensor input_t =
197         ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
198     InferenceContext c(kVersion, def, MakeOpDef(2, 2),
199                        {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
200     TF_ASSERT_OK(c.construction_status());
201     auto fn = [](InferenceContext* c) {
202       c->input_tensor(0);  // get this one, but it's null - won't be in error.
203       c->input_tensor(1);  // get this one, will now be in error.
204       ShapeHandle h;
205       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
206       c->set_output(0, c->input(0));
207       return Status::OK();
208     };
209     auto s = c.Run(fn).ToString();
210     EXPECT_CONTAINS(s, "Shape must be at most rank 0 but is rank 3");
211     EXPECT_CONTAINS(s, "node foo");
212     EXPECT_CONTAINS(s, "foo_op");
213     EXPECT_CONTAINS(
214         s,
215         "input shapes: [1,2,3], [4,5] and with computed input tensors: "
216         "input[1] = <1.1 2.2 3.3 4.4 5.5>.");
217   }
218 
219   // Error when a constant tensor value as shape was requested, but no partial
220   // shapes provided.
221   {
222     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
223     InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
224                        {nullptr, &input_t}, {}, {});
225     TF_ASSERT_OK(c.construction_status());
226     auto fn = [](InferenceContext* c) {
227       ShapeHandle s;
228       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
229       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
230       ShapeHandle h;
231       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
232       c->set_output(0, c->input(0));
233       return Status::OK();
234     };
235     auto s = c.Run(fn).ToString();
236     EXPECT_CONTAINS(s, "Shape must be at most rank 0 but is rank 1");
237     EXPECT_CONTAINS(s, "node foo");
238     EXPECT_CONTAINS(s, "foo_op");
239     EXPECT_CONTAINS(
240         s,
241         "with input shapes: [3], [4] and with computed input tensors: input[1] "
242         "= <1 2 3 4 5>.");
243   }
244 
245   // Error when a constant tensor value as shape was requested, and a partial
246   // shape was provided.
247   {
248     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
249     InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
250                        {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
251     TF_ASSERT_OK(c.construction_status());
252     auto fn = [](InferenceContext* c) {
253       ShapeHandle s;
254       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
255       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
256       ShapeHandle h;
257       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
258       c->set_output(0, c->input(0));
259       return Status::OK();
260     };
261     auto s = c.Run(fn).ToString();
262     EXPECT_CONTAINS(s, "Shape must be at most rank 0 but is rank 1");
263     EXPECT_CONTAINS(s, "node foo");
264     EXPECT_CONTAINS(s, "foo_op");
265     EXPECT_CONTAINS(
266         s,
267         "with input shapes: [3], [4] and with computed "
268         "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed "
269         "as partial shapes: input[0] = [10,?,5].");
270   }
271 }
272 
// Rank/dimension queries on unknown-rank, partially known and scalar inputs.
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
  EXPECT_EQ(3, ctx.num_inputs());
  EXPECT_EQ(2, ctx.num_outputs());

  // Unknown rank: every dimension lookup, even out of range, is unknown.
  const ShapeHandle unknown_rank = ctx.input(0);
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_FALSE(ctx.RankKnown(unknown_rank));
  EXPECT_EQ(InferenceContext::kUnknownRank, ctx.Rank(unknown_rank));
  for (const int64 idx : {0, -1, 1000}) {
    EXPECT_EQ("?", ctx.DebugString(ctx.Dim(unknown_rank, idx)));
  }

  // Known rank: negative indices count from the end and yield the very same
  // dimension handle.
  const ShapeHandle partial = ctx.input(1);
  EXPECT_EQ("[1,?,3]", ctx.DebugString(partial));
  EXPECT_TRUE(ctx.RankKnown(partial));
  EXPECT_EQ(3, ctx.Rank(partial));

  const DimensionHandle first = ctx.Dim(partial, 0);
  EXPECT_EQ(1, ctx.Value(first));
  EXPECT_TRUE(SameHandle(first, ctx.Dim(partial, -3)));
  EXPECT_TRUE(ctx.ValueKnown(first));
  EXPECT_EQ("1", ctx.DebugString(first));

  const DimensionHandle middle = ctx.Dim(partial, 1);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(middle));
  EXPECT_FALSE(ctx.ValueKnown(middle));
  EXPECT_TRUE(SameHandle(middle, ctx.Dim(partial, -2)));
  EXPECT_EQ("?", ctx.DebugString(middle));

  const DimensionHandle last = ctx.Dim(partial, 2);
  EXPECT_EQ(3, ctx.Value(last));
  EXPECT_TRUE(SameHandle(last, ctx.Dim(partial, -1)));
  EXPECT_TRUE(ctx.ValueKnown(last));
  EXPECT_EQ("3", ctx.DebugString(last));

  // Scalar: known rank 0.
  const ShapeHandle scalar = ctx.input(2);
  EXPECT_EQ("[]", ctx.DebugString(scalar));
  EXPECT_TRUE(ctx.RankKnown(scalar));
  EXPECT_EQ(0, ctx.Rank(scalar));
}
313 
// NumElements is unknown unless every dimension is known.
TEST_F(ShapeInferenceTest, NumElements) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {},
                       {});

  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(ctx.input(0))));
  const ShapeHandle partial = ctx.input(1);
  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(partial)));

  // The unknown element count is its own handle, not the input's unknown dim.
  EXPECT_FALSE(SameHandle(ctx.Dim(partial, 1), ctx.NumElements(partial)));

  // Fully defined: 5 * 4 * 3 * 2.
  EXPECT_EQ("120", ctx.DebugString(ctx.NumElements(ctx.input(2))));
}
327 
TEST_F(ShapeInferenceTest, WithRank) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});

  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle out_a;
  ShapeHandle out_b;

  // Against an unknown-rank shape, WithRank always succeeds. Each call builds
  // a fresh shape of the requested rank with fresh unknown dimensions.
  EXPECT_TRUE(ctx.WithRank(unknown_rank, 1, &out_a).ok());
  EXPECT_EQ("[?]", ctx.DebugString(out_a));

  EXPECT_TRUE(ctx.WithRank(unknown_rank, 2, &out_b).ok());
  EXPECT_EQ("[?,?]", ctx.DebugString(out_b));
  EXPECT_FALSE(SameHandle(out_a, out_b));
  EXPECT_FALSE(SameHandle(ctx.Dim(out_b, 0), ctx.Dim(out_b, 1)));

  EXPECT_TRUE(ctx.WithRank(unknown_rank, 1, &out_b).ok());
  EXPECT_EQ("[?]", ctx.DebugString(out_b));
  EXPECT_FALSE(SameHandle(out_a, out_b));

  EXPECT_TRUE(ctx.WithRank(unknown_rank, 0, &out_a).ok());
  EXPECT_EQ("[]", ctx.DebugString(out_a));

  // Against a known rank, a mismatch fails and clears the output handle.
  // A match returns the input handle itself.
  out_a = rank3;
  const Status status = ctx.WithRank(rank3, 2, &out_a);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(), "Shape must be rank 2 but is rank 3");
  EXPECT_FALSE(IsSet(out_a));
  EXPECT_TRUE(ctx.WithRank(rank3, 3, &out_a).ok());
  EXPECT_TRUE(SameHandle(out_a, rank3));

  // The inputs themselves are never modified.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
367 
TEST_F(ShapeInferenceTest, WithRankAtMost) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});

  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle out_a;
  ShapeHandle out_b;

  // An unknown-rank shape satisfies any upper bound and is returned as-is.
  EXPECT_TRUE(ctx.WithRankAtMost(unknown_rank, 1, &out_a).ok());
  EXPECT_EQ("?", ctx.DebugString(out_a));
  EXPECT_TRUE(SameHandle(unknown_rank, out_a));

  EXPECT_TRUE(ctx.WithRankAtMost(unknown_rank, 2, &out_b).ok());
  EXPECT_EQ("?", ctx.DebugString(out_b));
  EXPECT_TRUE(SameHandle(out_a, out_b));

  // A rank-3 shape violates a bound of 2: error, and the output is cleared.
  out_a = rank3;
  const Status status = ctx.WithRankAtMost(rank3, 2, &out_a);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Shape must be at most rank 2 but is rank 3");
  EXPECT_FALSE(IsSet(out_a));

  // Any bound >= 3 is satisfied and returns the input handle itself.
  for (const int max_rank : {3, 4, 5}) {
    EXPECT_TRUE(ctx.WithRankAtMost(rank3, max_rank, &out_a).ok());
    EXPECT_TRUE(SameHandle(out_a, rank3));
  }

  // The inputs themselves are never modified.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
406 
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});

  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle out_a;
  ShapeHandle out_b;

  // An unknown-rank shape satisfies any lower bound and is returned as-is.
  EXPECT_TRUE(ctx.WithRankAtLeast(unknown_rank, 1, &out_a).ok());
  EXPECT_EQ("?", ctx.DebugString(out_a));
  EXPECT_TRUE(SameHandle(unknown_rank, out_a));

  EXPECT_TRUE(ctx.WithRankAtLeast(unknown_rank, 2, &out_b).ok());
  EXPECT_EQ("?", ctx.DebugString(out_b));
  EXPECT_TRUE(SameHandle(out_a, out_b));

  // A rank-3 shape violates a bound of 4: error, and the output is cleared.
  out_a = rank3;
  const Status status = ctx.WithRankAtLeast(rank3, 4, &out_a);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Shape must be at least rank 4 but is rank 3");
  EXPECT_FALSE(IsSet(out_a));

  // Any bound <= 3 is satisfied and returns the input handle itself.
  for (const int min_rank : {3, 2, 0}) {
    EXPECT_TRUE(ctx.WithRankAtLeast(rank3, min_rank, &out_a).ok());
    EXPECT_TRUE(SameHandle(out_a, rank3));
  }

  // The inputs themselves are never modified.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
445 
TEST_F(ShapeInferenceTest, WithValue) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1, -1})}, {},
                       {}, {});

  const DimensionHandle known_one = ctx.Dim(ctx.input(0), 0);
  const DimensionHandle unknown_dim = ctx.Dim(ctx.input(0), 1);
  DimensionHandle out_a;
  DimensionHandle out_b;

  // On an unknown dimension, WithValue always succeeds. Each call yields a
  // new handle carrying the requested value.
  EXPECT_TRUE(ctx.WithValue(unknown_dim, 1, &out_a).ok());
  EXPECT_EQ(1, ctx.Value(out_a));

  EXPECT_TRUE(ctx.WithValue(unknown_dim, 2, &out_b).ok());
  EXPECT_EQ(2, ctx.Value(out_b));
  EXPECT_FALSE(SameHandle(out_a, out_b));
  EXPECT_FALSE(SameHandle(out_a, unknown_dim));

  EXPECT_TRUE(ctx.WithValue(unknown_dim, 1, &out_b).ok());
  EXPECT_EQ(1, ctx.Value(out_b));
  EXPECT_FALSE(SameHandle(out_a, out_b));

  // On a known dimension, a different value is an error that clears the
  // output handle.
  out_a = known_one;
  Status status = ctx.WithValue(known_one, 0, &out_a);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(), "Dimension must be 0 but is 1");
  EXPECT_FALSE(IsSet(out_a));

  out_a = known_one;
  status = ctx.WithValue(known_one, 2, &out_a);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(), "Dimension must be 2 but is 1");
  EXPECT_FALSE(IsSet(out_a));

  // The matching value returns the input handle itself.
  EXPECT_TRUE(ctx.WithValue(known_one, 1, &out_a).ok());
  EXPECT_TRUE(SameHandle(known_one, out_a));

  // The inputs themselves are never modified.
  EXPECT_EQ("1", ctx.DebugString(known_one));
  EXPECT_EQ("?", ctx.DebugString(unknown_dim));
}
488 
TEST_F(ShapeInferenceTest,MergeDim)489 TEST_F(ShapeInferenceTest, MergeDim) {
490   NodeDef def;
491   InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {},
492                      {}, {});
493 
494   auto d2 = c.Dim(c.input(0), 0);
495   auto d_unknown = c.Dim(c.input(0), 1);
496   auto d2_b = c.Dim(c.input(0), 2);
497   auto d1 = c.Dim(c.input(0), 3);
498   auto d_unknown_b = c.Dim(c.input(0), 4);
499   DimensionHandle out;
500 
501   // Merging anything with unknown returns the same pointer.
502   EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok());
503   EXPECT_TRUE(SameHandle(d2, out));
504   EXPECT_TRUE(c.Merge(d_unknown, d2, &out).ok());
505   EXPECT_TRUE(SameHandle(d2, out));
506   EXPECT_TRUE(c.Merge(d_unknown, d_unknown_b, &out).ok());
507   EXPECT_TRUE(SameHandle(d_unknown, out));
508 
509   auto merged_dims = c.MergedDims();
510   ASSERT_EQ(3, merged_dims.size());
511   EXPECT_TRUE(merged_dims[0].first.SameHandle(d2));
512   EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown));
513   EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown));
514   EXPECT_TRUE(merged_dims[1].second.SameHandle(d2));
515   EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown));
516   EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b));
517 
518   // Merging with self is a no-op and returns self.
519   EXPECT_TRUE(c.Merge(d2, d2, &out).ok());
520   EXPECT_TRUE(SameHandle(d2, out));
521   EXPECT_TRUE(c.Merge(d_unknown, d_unknown, &out).ok());
522   EXPECT_TRUE(SameHandle(d_unknown, out));
523 
524   merged_dims = c.MergedDims();
525   EXPECT_EQ(3, merged_dims.size());
526 
527   // Merging equal values is a no op and returns first one.
528   EXPECT_TRUE(c.Merge(d2, d2_b, &out).ok());
529   EXPECT_TRUE(SameHandle(d2, out));
530   EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok());
531   EXPECT_TRUE(SameHandle(d2_b, out));
532 
533   merged_dims = c.MergedDims();
534   EXPECT_EQ(3, merged_dims.size());
535 
536   // Merging unequal values is an error.
537   Status status = c.Merge(d2, d1, &out);
538   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
539   EXPECT_CONTAINS(status.error_message(),
540                   "Dimensions must be equal, but are 2 and 1");
541 
542   EXPECT_FALSE(IsSet(out));
543   status = c.Merge(d1, d2, &out);
544   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
545   EXPECT_CONTAINS(status.error_message(),
546                   "Dimensions must be equal, but are 1 and 2");
547 
548   EXPECT_FALSE(IsSet(out));
549 
550   merged_dims = c.MergedDims();
551   EXPECT_EQ(3, merged_dims.size());
552 }
553 
TEST_F(ShapeInferenceTest, RelaxDim) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2),
                       {S({2, InferenceContext::kUnknownDim, 2, 1,
                           InferenceContext::kUnknownDim})},
                       {}, {}, {});

  const ShapeHandle in = ctx.input(0);
  const DimensionHandle two = ctx.Dim(in, 0);
  const DimensionHandle unknown = ctx.Dim(in, 1);
  const DimensionHandle two_b = ctx.Dim(in, 2);
  const DimensionHandle one = ctx.Dim(in, 3);
  const DimensionHandle unknown_b = ctx.Dim(in, 4);
  DimensionHandle relaxed;

  // If either side is unknown, the result is unknown. The second argument is
  // returned when it is the unknown one; otherwise a new unknown is made.
  Relax(&ctx, two, unknown, &relaxed);
  EXPECT_TRUE(SameHandle(unknown, relaxed));
  EXPECT_FALSE(SameHandle(unknown_b, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
  Relax(&ctx, unknown, two, &relaxed);
  EXPECT_FALSE(SameHandle(unknown, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
  Relax(&ctx, unknown, unknown_b, &relaxed);
  EXPECT_FALSE(SameHandle(unknown, relaxed));
  EXPECT_TRUE(SameHandle(unknown_b, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));

  // Relaxing a handle with itself returns it unchanged.
  Relax(&ctx, two, two, &relaxed);
  EXPECT_TRUE(SameHandle(two, relaxed));
  Relax(&ctx, unknown, unknown, &relaxed);
  EXPECT_TRUE(SameHandle(unknown, relaxed));

  // Equal known values return the first argument.
  Relax(&ctx, two, two_b, &relaxed);
  EXPECT_TRUE(SameHandle(two, relaxed));
  Relax(&ctx, two_b, two, &relaxed);
  EXPECT_TRUE(SameHandle(two_b, relaxed));

  // Different known values relax to unknown.
  Relax(&ctx, two, one, &relaxed);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
  Relax(&ctx, one, two, &relaxed);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
}
600 
TEST_F(ShapeInferenceTest,RelaxShape)601 TEST_F(ShapeInferenceTest, RelaxShape) {
602   NodeDef def;
603   InferenceContext c(
604       kVersion, def, MakeOpDef(7, 2),
605       {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
606        S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
607       {}, {}, {});
608 
609   auto s_unknown = c.input(0);
610   auto s_1_2 = c.input(1);
611   auto s_u_2 = c.input(2);
612   auto s_1_u = c.input(3);
613   auto s_1_3 = c.input(4);
614   auto s_unknown_b = c.input(5);
615   auto s_1 = c.input(6);
616   ShapeHandle out;
617 
618   // Relaxing any shape with unknown returns a new unknown.
619   Relax(&c, s_unknown, s_1_2, &out);
620   EXPECT_FALSE(SameHandle(s_u_2, s_unknown));
621   EXPECT_EQ("?", c.DebugString(out));
622   Relax(&c, s_u_2, s_unknown, &out);
623   EXPECT_FALSE(SameHandle(s_u_2, out));
624   EXPECT_EQ("?", c.DebugString(out));
625   Relax(&c, s_unknown, s_unknown_b, &out);
626   EXPECT_FALSE(SameHandle(s_unknown, out));
627   EXPECT_TRUE(SameHandle(s_unknown_b, out));
628   EXPECT_EQ("?", c.DebugString(out));
629 
630   // Relaxing with self returns self.
631   Relax(&c, s_1_2, s_1_2, &out);
632   EXPECT_TRUE(SameHandle(out, s_1_2));
633 
634   // Relaxing where one of the inputs has less information.
635   out = ShapeHandle();
636   Relax(&c, s_1_2, s_u_2, &out);
637   EXPECT_FALSE(SameHandle(s_u_2, out));
638   EXPECT_EQ("[?,2]", c.DebugString(out));
639   out = ShapeHandle();
640   Relax(&c, s_u_2, s_1_2, &out);
641   EXPECT_FALSE(SameHandle(s_u_2, out));
642   EXPECT_EQ("[?,2]", c.DebugString(out));
643 
644   // Relaxing where each input has one distinct unknown dimension.
645   Relax(&c, s_u_2, s_1_u, &out);
646   EXPECT_EQ("[?,?]", c.DebugString(out));
647   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
648   EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1)));
649   auto s_u1 = c.UnknownShapeOfRank(1);
650   auto s_u2 = c.UnknownShapeOfRank(1);
651   Relax(&c, s_u1, s_u2, &out);
652   EXPECT_FALSE(SameHandle(s_u1, out));
653 
654   // Relaxing with mismatched values in a dimension returns a shape with that
655   // dimension unknown.
656   out = s_unknown;
657   Relax(&c, s_u_2, s_1_3, &out);
658   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
659   EXPECT_EQ("[?,?]", c.DebugString(out));
660   out = s_unknown;
661   Relax(&c, s_1_3, s_u_2, &out);
662   EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
663   EXPECT_EQ("[?,?]", c.DebugString(out));
664   out = s_unknown;
665 
666   // Relaxing with mismatched ranks returns a new unknown.
667   Relax(&c, s_1, s_1_2, &out);
668   EXPECT_EQ("?", c.DebugString(out));
669 }
670 
TEST_F(ShapeInferenceTest,MergeShape)671 TEST_F(ShapeInferenceTest, MergeShape) {
672   NodeDef def;
673   InferenceContext c(kVersion, def, MakeOpDef(7, 2),
674                      {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
675                       Unknown(), S({1})},
676                      {}, {}, {});
677 
678   auto s_unknown = c.input(0);
679   auto s_1_2 = c.input(1);
680   auto s_u_2 = c.input(2);
681   auto s_1_u = c.input(3);
682   auto s_1_3 = c.input(4);
683   auto s_unknown_b = c.input(5);
684   auto s_1 = c.input(6);
685   ShapeHandle out;
686 
687   // Merging any shape with unknown returns the shape.
688   EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok());
689   EXPECT_TRUE(SameHandle(s_1_2, out));
690   EXPECT_TRUE(c.Merge(s_u_2, s_unknown, &out).ok());
691   EXPECT_TRUE(SameHandle(s_u_2, out));
692   EXPECT_TRUE(c.Merge(s_unknown, s_unknown_b, &out).ok());
693   EXPECT_TRUE(SameHandle(s_unknown, out));
694 
695   auto merged_shapes = c.MergedShapes();
696   ASSERT_EQ(3, merged_shapes.size());
697   EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown));
698   EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2));
699   EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2));
700   EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown));
701   EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown));
702   EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b));
703 
704   // Merging with self returns self.
705   EXPECT_TRUE(c.Merge(s_1_2, s_1_2, &out).ok());
706   EXPECT_TRUE(SameHandle(out, s_1_2));
707 
708   merged_shapes = c.MergedShapes();
709   EXPECT_EQ(3, merged_shapes.size());
710 
711   // Merging where one of the inputs is the right answer - return that input.
712   out = ShapeHandle();
713   EXPECT_TRUE(c.Merge(s_1_2, s_u_2, &out).ok());
714   EXPECT_TRUE(SameHandle(s_1_2, out));
715   out = ShapeHandle();
716   EXPECT_TRUE(c.Merge(s_u_2, s_1_2, &out).ok());
717   EXPECT_TRUE(SameHandle(s_1_2, out));
718 
719   merged_shapes = c.MergedShapes();
720   ASSERT_EQ(5, merged_shapes.size());
721   EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2));
722   EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2));
723   EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2));
724   EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2));
725 
726   // Merging where neither input is the right answer.
727   EXPECT_TRUE(c.Merge(s_u_2, s_1_u, &out).ok());
728   EXPECT_FALSE(SameHandle(out, s_u_2));
729   EXPECT_FALSE(SameHandle(out, s_1_u));
730   EXPECT_EQ("[1,2]", c.DebugString(out));
731   EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0)));
732   EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1)));
733 
734   merged_shapes = c.MergedShapes();
735   ASSERT_EQ(7, merged_shapes.size());
736   EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2));
737   EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u));
738   EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2));
739   EXPECT_TRUE(merged_shapes[6].second.SameHandle(out));
740 
741   auto s_u1 = c.UnknownShapeOfRank(1);
742   auto s_u2 = c.UnknownShapeOfRank(1);
743   TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out));
744   EXPECT_TRUE(SameHandle(s_u1, out));
745 
746   merged_shapes = c.MergedShapes();
747   ASSERT_EQ(8, merged_shapes.size());
748   EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1));
749   EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2));
750 
751   // Incompatible merges give errors and set out to nullptr.
752   out = s_unknown;
753   Status status = c.Merge(s_u_2, s_1_3, &out);
754   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
755   EXPECT_CONTAINS(status.error_message(),
756                   "Dimension 1 in both shapes must be equal, but are 2 and 3");
757 
758   EXPECT_FALSE(IsSet(out));
759   out = s_unknown;
760   status = c.Merge(s_1_3, s_u_2, &out);
761   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
762   EXPECT_CONTAINS(status.error_message(),
763                   "Dimension 1 in both shapes must be equal, but are 3 and 2");
764 
765   EXPECT_FALSE(IsSet(out));
766   out = s_unknown;
767   status = c.Merge(s_1, s_1_2, &out);
768   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
769   EXPECT_CONTAINS(status.error_message(),
770                   "Shapes must be equal rank, but are 1 and 2");
771 
772   EXPECT_FALSE(IsSet(out));
773 
774   merged_shapes = c.MergedShapes();
775   EXPECT_EQ(8, merged_shapes.size());
776 }
777 
TEST_F(ShapeInferenceTest, MergePrefix) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(4, 2),
                       {Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4})}, {},
                       {}, {});

  const ShapeHandle unknown = ctx.input(0);
  const ShapeHandle shape_u_2 = ctx.input(1);
  const ShapeHandle shape_1_u_3 = ctx.input(2);
  const ShapeHandle shape_2_4 = ctx.input(3);

  ShapeHandle merged;
  ShapeHandle merged_prefix;

  // If either side has unknown rank, both inputs are passed through as-is.
  EXPECT_TRUE(
      ctx.MergePrefix(unknown, shape_u_2, &merged, &merged_prefix).ok());
  EXPECT_TRUE(SameHandle(merged, unknown));
  EXPECT_TRUE(SameHandle(merged_prefix, shape_u_2));
  EXPECT_TRUE(
      ctx.MergePrefix(shape_1_u_3, unknown, &merged, &merged_prefix).ok());
  EXPECT_TRUE(SameHandle(merged, shape_1_u_3));
  EXPECT_TRUE(SameHandle(merged_prefix, unknown));

  // Otherwise the prefix is merged dimension by dimension into both outputs:
  // [1,?,3] with prefix [?,2] gives [1,2,3] and prefix [1,2].
  EXPECT_TRUE(
      ctx.MergePrefix(shape_1_u_3, shape_u_2, &merged, &merged_prefix).ok());
  EXPECT_FALSE(SameHandle(merged, shape_1_u_3));
  EXPECT_EQ("[1,2]", ctx.DebugString(merged_prefix));
  EXPECT_EQ("[1,2,3]", ctx.DebugString(merged));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 0), ctx.Dim(merged, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged, 0), ctx.Dim(shape_1_u_3, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 1), ctx.Dim(merged, 1)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 1), ctx.Dim(shape_u_2, 1)));

  // Incompatible prefix dimensions fail and leave both outputs unset.
  merged = unknown;
  merged_prefix = unknown;
  Status status =
      ctx.MergePrefix(shape_1_u_3, shape_2_4, &merged, &merged_prefix);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Dimensions must be equal, but are 1 and 2");
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));

  // A prefix longer than the shape it is merged into fails the same way.
  merged = unknown;
  merged_prefix = unknown;
  status = ctx.MergePrefix(shape_2_4, shape_1_u_3, &merged, &merged_prefix);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_CONTAINS(status.error_message(),
                  "Shape must be at least rank 3 but is rank 2");
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));
}
834 
TEST_F(ShapeInferenceTest,Subshape)835 TEST_F(ShapeInferenceTest, Subshape) {
836   NodeDef def;
837   InferenceContext c(kVersion, def, MakeOpDef(2, 2),
838                      {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
839 
840   ShapeHandle unknown = c.input(1);
841   ShapeHandle out;
842   EXPECT_TRUE(c.Subshape(unknown, 0, &out).ok());
843   EXPECT_EQ("?", c.DebugString(out));
844   EXPECT_TRUE(SameHandle(out, unknown));
845   EXPECT_TRUE(c.Subshape(unknown, 1, &out).ok());
846   EXPECT_EQ("?", c.DebugString(out));
847   EXPECT_FALSE(SameHandle(out, unknown));
848   EXPECT_TRUE(c.Subshape(unknown, 200, &out).ok());
849   EXPECT_EQ("?", c.DebugString(out));
850   EXPECT_FALSE(SameHandle(out, unknown));
851 
852   const int kFullRank = 5;
853   ShapeHandle out_arr[4];
854   auto in0 = c.input(0);
855   EXPECT_TRUE(c.Subshape(in0, 0, &out).ok());
856   EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out));
857   EXPECT_TRUE(SameHandle(out, in0));
858   EXPECT_EQ(kFullRank, c.Rank(out));
859   for (int start = 0; start <= kFullRank + 1; ++start) {
860     for (int end = start; end <= kFullRank + 1; ++end) {
861       // Get subshapes using different start and end values that give the same
862       // range.
863       const int neg_start =
864           start >= kFullRank ? kFullRank : (start - kFullRank);
865       const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank);
866       ASSERT_TRUE(c.Subshape(in0, start, end, &out_arr[0]).ok());
867       ASSERT_TRUE(c.Subshape(in0, neg_start, end, &out_arr[1]).ok());
868       ASSERT_TRUE(c.Subshape(in0, start, neg_end, &out_arr[2]).ok());
869       ASSERT_TRUE(c.Subshape(in0, neg_start, neg_end, &out_arr[3]).ok());
870 
871       // Verify all computed subshapes.
872       for (int arr_idx = 0; arr_idx < 4; ++arr_idx) {
873         out = out_arr[arr_idx];
874         ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start),
875                   c.Rank(out))
876             << "start: " << start << " end: " << end << " arr_idx: " << arr_idx
877             << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out);
878         for (int d = 0; d < c.Rank(out); ++d) {
879           EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d)))
880               << "arr_idx: " << arr_idx;
881         }
882       }
883     }
884   }
885 
886   // Errors.
887   out = unknown;
888   Status status = c.Subshape(in0, 6, -3, &out);
889   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
890   EXPECT_CONTAINS(
891       status.error_message(),
892       "Subshape must have computed start <= end, but is 5 "
893       "and 2 (computed from start 6 and end -3 over shape with rank 5)");
894   EXPECT_FALSE(IsSet(out));
895   out = unknown;
896   status = c.Subshape(in0, -50, 100, &out);
897   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
898   EXPECT_CONTAINS(status.error_message(),
899                   "Subshape start out of bounds: -50, for shape with rank 5");
900 
901   EXPECT_FALSE(IsSet(out));
902   out = unknown;
903   status = c.Subshape(in0, 0, -50, &out);
904   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
905   EXPECT_CONTAINS(status.error_message(),
906                   "Subshape end out of bounds: -50, for shape with rank 5");
907 
908   EXPECT_FALSE(IsSet(out));
909 }
910 
// Concatenate() yields an unknown shape when either operand has unknown rank.
// Otherwise the result holds the operands' dimension handles, in order.
TEST_F(ShapeInferenceTest, Concatenate) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                     {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});

  auto in0 = c.input(0);
  auto in1 = c.input(1);
  ShapeHandle unknown = c.input(2);
  ShapeHandle out;

  // An unknown-rank operand on either side gives an unknown result that is not
  // the unknown input's handle.
  EXPECT_TRUE(c.Concatenate(unknown, unknown, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));
  EXPECT_TRUE(c.Concatenate(unknown, in0, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));
  EXPECT_TRUE(c.Concatenate(in0, unknown, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  // Known ranks: ranks add up and every dimension handle is reused.
  EXPECT_TRUE(c.Concatenate(in0, in1, &out).ok());
  EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out));
  // ASSERT so the Dim(out, out_i) lookups below stay in range.
  ASSERT_EQ(c.Rank(in0) + c.Rank(in1), c.Rank(out));
  int out_i = 0;
  for (int i = 0; i < c.Rank(in0); ++i, ++out_i) {
    EXPECT_TRUE(SameHandle(c.Dim(in0, i), c.Dim(out, out_i)));
  }
  for (int i = 0; i < c.Rank(in1); ++i, ++out_i) {
    EXPECT_TRUE(SameHandle(c.Dim(in1, i), c.Dim(out, out_i)));
  }
}
937 
// ReplaceDim() substitutes one dimension (negative indices count from the end)
// while sharing every dimension handle. Unknown-rank inputs stay unknown, and
// out-of-range indices fail and clear the output.
TEST_F(ShapeInferenceTest, ReplaceDim) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
                     {}, {}, {});

  auto in = c.input(0);
  auto unknown = c.input(1);

  ShapeHandle replaced;
  EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
  // The new dimension and the untouched ones are the same handles, not copies.
  EXPECT_TRUE(SameHandle(c.Dim(replaced, 0), c.Dim(in, 1)));
  EXPECT_TRUE(SameHandle(c.Dim(replaced, 1), c.Dim(in, 1)));
  EXPECT_TRUE(SameHandle(c.Dim(replaced, 2), c.Dim(in, 2)));
  EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok());
  EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
  EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("?", c.DebugString(replaced));

  // Negative indexing.
  EXPECT_TRUE(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  EXPECT_TRUE(SameHandle(c.Dim(replaced, 2), c.Dim(in, 1)));
  EXPECT_TRUE(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("?", c.DebugString(replaced));

  // Out-of-range indexing fails and clears the output. Seed `replaced` with a
  // set handle each time so the IsSet check proves the call cleared it.
  replaced = in;
  EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
  replaced = in;
  EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
}
969 
// MakeShape() returns a distinct shape handle on each call. It reuses
// dimension handles it is given and builds dimensions for constants.
TEST_F(ShapeInferenceTest, MakeShape) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
                     {}, {});

  // Collect the input's dimension handles in reverse order.
  std::vector<DimensionHandle> dims;
  auto in0 = c.input(0);
  const int rank = c.Rank(in0);
  dims.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    dims.push_back(c.Dim(in0, rank - i - 1));
  }

  auto s = c.MakeShape(dims);
  EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s));
  // Every dimension of the new shape is the handle that was passed in.
  ASSERT_EQ(rank, c.Rank(s));
  for (int i = 0; i < rank; ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(s, i), dims[i])) << "i: " << i;
  }
  EXPECT_TRUE(SameHandle(c.Dim(s, 0), c.Dim(in0, rank - 1)));

  // The same dimensions produce a different shape handle sharing those dims.
  auto s2 = c.MakeShape(dims);
  EXPECT_FALSE(SameHandle(s, s2));
  EXPECT_TRUE(SameHandle(c.Dim(s2, 0), c.Dim(in0, rank - 1)));

  // Constants and handles can be mixed; the handle is reused as-is.
  auto s3 = c.MakeShape({1, 2, dims[2]});
  EXPECT_FALSE(SameHandle(s, s3));
  EXPECT_EQ("[1,2,3]", c.DebugString(s3));
  EXPECT_TRUE(SameHandle(c.Dim(s3, 2), dims[2]));
}
995 
// Each UnknownShape() call returns a distinct unknown-rank shape.
TEST_F(ShapeInferenceTest, UnknownShape) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const ShapeHandle first = ctx.UnknownShape();
  const ShapeHandle second = ctx.UnknownShape();
  EXPECT_EQ("?", ctx.DebugString(first));
  EXPECT_EQ("?", ctx.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
}
1007 
// ShapeHandleToProto() on a fully known shape records every dimension size and
// leaves unknown_rank false.
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s = c.MakeShape({1, 2, 3});
  TensorShapeProto proto;
  c.ShapeHandleToProto(s, &proto);

  EXPECT_FALSE(proto.unknown_rank());
  // ASSERT so the dim(i) accesses below never index past the repeated field.
  ASSERT_EQ(3, proto.dim_size());
  for (int i = 0; i < proto.dim_size(); ++i) {
    EXPECT_EQ(i + 1, proto.dim(i).size()) << "dim " << i;
  }
}
1021 
// An unknown shape serializes as unknown_rank with no dimensions listed.
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  TensorShapeProto shape_proto;
  ctx.ShapeHandleToProto(ctx.UnknownShape(), &shape_proto);

  EXPECT_TRUE(shape_proto.unknown_rank());
  EXPECT_EQ(0, shape_proto.dim_size());
}
1034 
// Scalar() always produces a rank-0 shape.
TEST_F(ShapeInferenceTest, Scalar) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  for (int call = 0; call < 2; ++call) {
    EXPECT_EQ("[]", ctx.DebugString(ctx.Scalar())) << "call " << call;
  }
}
1045 
// Vector() builds a rank-1 shape from a constant or a dimension handle.
TEST_F(ShapeInferenceTest, Vector) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  // From constants, including the unknown-dimension sentinel.
  EXPECT_EQ("[1]", ctx.DebugString(ctx.Vector(1)));
  EXPECT_EQ("[?]", ctx.DebugString(ctx.Vector(InferenceContext::kUnknownDim)));

  // From a dimension handle, which is reused as the single dimension.
  const DimensionHandle unknown_dim = ctx.UnknownDim();
  const ShapeHandle from_handle = ctx.Vector(unknown_dim);
  EXPECT_EQ("[?]", ctx.DebugString(from_handle));
  EXPECT_TRUE(SameHandle(unknown_dim, ctx.Dim(from_handle, 0)));
}
1061 
// Matrix() builds a rank-2 shape from constants and/or dimension handles;
// handles are reused as the corresponding dimensions.
TEST_F(ShapeInferenceTest, Matrix) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s0 = c.Matrix(1, 2);
  EXPECT_EQ("[1,2]", c.DebugString(s0));
  auto s1 = c.Matrix(0, InferenceContext::kUnknownDim);
  EXPECT_EQ("[0,?]", c.DebugString(s1));

  auto d1 = c.UnknownDim();
  auto d2 = c.UnknownDim();
  auto s2 = c.Matrix(d1, d2);
  EXPECT_EQ("[?,?]", c.DebugString(s2));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0)));
  EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1)));

  // Mixing a handle with a constant still reuses the handle. (This previously
  // re-checked s2 instead of s3, so s3's first dimension was never verified.)
  auto s3 = c.Matrix(d1, 100);
  EXPECT_EQ("[?,100]", c.DebugString(s3));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s3, 0)));
}
1083 
TEST_F(ShapeInferenceTest,MakeShapeFromShapeTensor)1084 TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
1085   auto create = [&](Tensor* t) {
1086     NodeDef def;
1087     InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
1088                        {});
1089     ShapeHandle out;
1090     Status s = c.MakeShapeFromShapeTensor(0, &out);
1091     if (s.ok()) {
1092       return c.DebugString(out);
1093     } else {
1094       EXPECT_FALSE(IsSet(out));
1095       return s.error_message();
1096     }
1097   };
1098 
1099   Tensor t;
1100   EXPECT_EQ("?", create(nullptr));
1101 
1102   t = ::tensorflow::test::AsTensor<int32>({1, 2, 3});
1103   EXPECT_EQ("[1,2,3]", create(&t));
1104 
1105   t = ::tensorflow::test::AsTensor<int64>({3, 2, 1});
1106   EXPECT_EQ("[3,2,1]", create(&t));
1107 
1108   t = ::tensorflow::test::AsTensor<int64>({3, -1, 1});
1109   EXPECT_EQ("[3,?,1]", create(&t));
1110 
1111   t = ::tensorflow::test::AsTensor<int64>({});
1112   EXPECT_EQ("[]", create(&t));
1113 
1114   // Test negative scalar
1115   t = ::tensorflow::test::AsScalar<int32>(-1);
1116   EXPECT_EQ("?", create(&t));
1117 
1118   t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
1119   EXPECT_CONTAINS(create(&t),
1120                   "Input tensor must be int32 or int64, but was float");
1121 
1122   t = ::tensorflow::test::AsScalar<int32>(1);
1123   auto s_scalar = create(&t);
1124   EXPECT_CONTAINS(
1125       s_scalar,
1126       "Input tensor must be rank 1, or if its rank 0 it must have value -1");
1127 
1128   t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
1129   auto s_matrix = create(&t);
1130   EXPECT_CONTAINS(s_matrix, "Input tensor must be rank 1, but was rank 2");
1131 
1132   // Test negative values for the dims.
1133   t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
1134   EXPECT_CONTAINS(create(&t), "Invalid value in tensor used for shape: -2");
1135 
1136   // Test negative values for the dims.
1137   t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
1138   EXPECT_CONTAINS(create(&t), "Invalid value in tensor used for shape: -2");
1139 
1140   // Test when the input shape is wrong.
1141   {
1142     NodeDef def;
1143     InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
1144                        {}, {});
1145     ShapeHandle out;
1146     EXPECT_EQ("Shape must be rank 1 but is rank 2",
1147               c.MakeShapeFromShapeTensor(0, &out).error_message());
1148   }
1149 }
1150 
// MakeShapeFromPartialTensorShape() maps unknown rank to "?" and -1 dimensions
// to unknown dimensions.
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  struct Case {
    PartialTensorShape shape;
    const char* expected;
  };
  const Case cases[] = {
      {PartialTensorShape(), "?"},       // Unknown rank.
      {PartialTensorShape({0}), "[0]"},  // Known rank.
      {PartialTensorShape({0, -1, 1000}), "[0,?,1000]"},
  };

  ShapeHandle result;
  for (const Case& tc : cases) {
    TF_ASSERT_OK(ctx.MakeShapeFromPartialTensorShape(tc.shape, &result));
    EXPECT_EQ(tc.expected, ctx.DebugString(result));
  }
}
1169 
// MakeShapeFromTensorShape() converts a fully defined TensorShape.
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  struct Case {
    TensorShape shape;
    const char* expected;
  };
  const Case cases[] = {
      {TensorShape(), "[]"},
      {TensorShape({0}), "[0]"},
      {TensorShape({0, 7, 1000}), "[0,7,1000]"},
  };

  ShapeHandle result;
  for (const Case& tc : cases) {
    TF_ASSERT_OK(ctx.MakeShapeFromTensorShape(tc.shape, &result));
    EXPECT_EQ(tc.expected, ctx.DebugString(result));
  }
}
1183 
// MakeShapeFromShapeProto() validates a TensorShapeProto as it converts it.
// The proto is grown step by step, so the checks below are order-dependent.
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});
  TensorShapeProto shape_proto;
  ShapeHandle result;

  // Unknown rank is accepted only while no dimensions are listed.
  shape_proto.set_unknown_rank(true);
  TF_EXPECT_OK(ctx.MakeShapeFromShapeProto(shape_proto, &result));
  EXPECT_EQ("?", ctx.DebugString(result));
  shape_proto.add_dim()->set_size(0);
  EXPECT_CONTAINS(
      ctx.MakeShapeFromShapeProto(shape_proto, &result).error_message(),
      "An unknown shape must not have any dimensions set");
  EXPECT_FALSE(IsSet(result));

  // Known rank: -1 maps to an unknown dimension.
  shape_proto.set_unknown_rank(false);
  TF_EXPECT_OK(ctx.MakeShapeFromShapeProto(shape_proto, &result));
  EXPECT_EQ("[0]", ctx.DebugString(result));
  shape_proto.add_dim()->set_size(-1);
  shape_proto.add_dim()->set_size(1000);
  TF_EXPECT_OK(ctx.MakeShapeFromShapeProto(shape_proto, &result));
  EXPECT_EQ("[0,?,1000]", ctx.DebugString(result));

  // Anything below -1 is invalid and clears the output.
  shape_proto.add_dim()->set_size(-2);
  EXPECT_CONTAINS(
      ctx.MakeShapeFromShapeProto(shape_proto, &result).error_message(),
      "Shape [0,?,1000,-2] has dimensions with values below -1 "
      "(where -1 means unknown)");
  EXPECT_FALSE(IsSet(result));
}
1217 
// MakeDim() creates a new dimension handle per call, even for equal values.
TEST_F(ShapeInferenceTest, MakeDim) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const DimensionHandle one_a = ctx.MakeDim(1);
  const DimensionHandle one_b = ctx.MakeDim(1);
  const DimensionHandle two = ctx.MakeDim(2);

  EXPECT_EQ("1", ctx.DebugString(one_a));
  EXPECT_EQ("1", ctx.DebugString(one_b));
  EXPECT_FALSE(SameHandle(one_a, one_b));
  EXPECT_EQ("2", ctx.DebugString(two));
}
1231 
// Each UnknownDim() call returns a distinct unknown dimension.
TEST_F(ShapeInferenceTest, UnknownDim) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const DimensionHandle first = ctx.UnknownDim();
  const DimensionHandle second = ctx.UnknownDim();
  EXPECT_EQ("?", ctx.DebugString(first));
  EXPECT_EQ("?", ctx.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
}
1243 
// UnknownShapeOfRank(n) has n unknown dimensions; rank 0 is a scalar.
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  EXPECT_EQ("[?,?,?]", ctx.DebugString(ctx.UnknownShapeOfRank(3)));
  EXPECT_EQ("[]", ctx.DebugString(ctx.UnknownShapeOfRank(0)));
}
1255 
// input_tensor(i) returns the tensor pointer supplied for input i, or null for
// inputs beyond the supplied list.
TEST_F(ShapeInferenceTest, InputTensors) {
  const Tensor first = tensorflow::test::AsTensor<float>({10});
  const Tensor second = tensorflow::test::AsTensor<float>({20, 30});
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(3, 2),
                       {S({1}), S({2}), S({3})}, {&first, &second}, {}, {});

  EXPECT_EQ(&first, ctx.input_tensor(0));
  EXPECT_EQ(&second, ctx.input_tensor(1));
  EXPECT_EQ(nullptr, ctx.input_tensor(2));
}
1267 
// MakeDimForScalarInput() reads a scalar input tensor (int32 or int64) as a
// dimension size and rejects negative values.
TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
  // The context is given pointers to these tensors, so reassigning them below
  // changes what it reads.
  Tensor size_tensor = tensorflow::test::AsScalar<int32>(20);
  Tensor negative_tensor = tensorflow::test::AsScalar<int32>(-1);
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2), {S({}), S({})},
                       {&size_tensor, &negative_tensor}, {}, {});

  auto check_scalar_inputs = [&ctx]() {
    DimensionHandle dim;
    TF_EXPECT_OK(ctx.MakeDimForScalarInput(0, &dim));
    EXPECT_EQ("20", ctx.DebugString(dim));
    EXPECT_CONTAINS(ctx.MakeDimForScalarInput(1, &dim).error_message(),
                    "Dimension size, given by scalar input 1, must be "
                    "non-negative but is -1");
  };

  // int32 scalars.
  check_scalar_inputs();

  // The same checks with int64 scalars.
  size_tensor = tensorflow::test::AsScalar<int64>(20);
  negative_tensor = tensorflow::test::AsScalar<int64>(-1);
  check_scalar_inputs();
}
1293 
// GetAttr() reads an attribute from the NodeDef the context was built with.
TEST_F(ShapeInferenceTest, GetAttr) {
  OpRegistrationData op_reg_data;
  op_reg_data.op_def = MakeOpDef(0, 2);
  NodeDef def;
  // TF_ASSERT_OK reports the failing Status and aborts only this test. A bare
  // CHECK would crash the whole test binary without saying why.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_reg_data.op_def)
                   .Attr("foo", "bar")
                   .Finalize(&def));

  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {});
  string value;
  TF_EXPECT_OK(c.GetAttr("foo", &value));
  EXPECT_EQ("bar", value);
}
1309 
// Divide() divides a dimension by a constant or a dimension. Divisor 1 returns
// the input, unknowns propagate, non-positive divisors are rejected, and
// evenly_divisible controls whether a remainder is an error.
TEST_F(ShapeInferenceTest, Divide) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
                     {}, {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_1 = c.Dim(s, 2);
  auto d_2 = c.Dim(s, 3);
  auto d_0 = c.Dim(s, 4);
  bool evenly_divisible = true;

  // Dividing unknown by non-1 gives new unknown.
  DimensionHandle out;
  EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Dividing anything by 1 returns the input.
  EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_unknown));
  EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_6));
  EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_unknown));
  EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok());
  EXPECT_TRUE(SameHandle(out, d_6));

  EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok());
  EXPECT_EQ("3", c.DebugString(out));
  EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
  EXPECT_EQ("3", c.DebugString(out));

  EXPECT_CONTAINS(c.Divide(d_6, 5, evenly_divisible, &out).error_message(),
                  "Dimension size must be evenly divisible by 5 but is 6");

  EXPECT_CONTAINS(c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");
  EXPECT_CONTAINS(c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");

  EXPECT_CONTAINS(c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is -1");

  // Repeat the error cases above with evenly_divisible=false. A remainder is
  // now allowed, but non-positive divisors, constant or dimension, are still
  // rejected.
  evenly_divisible = false;
  EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
  EXPECT_EQ("1", c.DebugString(out));

  EXPECT_CONTAINS(c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");
  EXPECT_CONTAINS(c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is 0");

  EXPECT_CONTAINS(c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
                  "Divisor must be positive but is -1");
}
1366 
// Add() sums a dimension with a constant or a dimension. Adding zero returns
// the input, unknowns propagate, and int64 overflow is an error.
TEST_F(ShapeInferenceTest, Add) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({6, -1, 0})},
                       {}, {}, {});
  const int64 kMaxDim = std::numeric_limits<int64>::max();

  const ShapeHandle shape = ctx.input(0);
  const DimensionHandle six = ctx.Dim(shape, 0);
  const DimensionHandle unknown = ctx.Dim(shape, 1);
  const DimensionHandle zero = ctx.Dim(shape, 2);
  DimensionHandle sum;

  // A non-zero constant added to an unknown gives a new unknown.
  TF_EXPECT_OK(ctx.Add(unknown, 1, &sum));
  EXPECT_EQ("?", ctx.DebugString(sum));
  EXPECT_FALSE(SameHandle(sum, unknown));

  // Adding a constant zero returns the input handle.
  TF_EXPECT_OK(ctx.Add(unknown, 0, &sum));
  EXPECT_TRUE(SameHandle(sum, unknown));
  TF_EXPECT_OK(ctx.Add(six, 0, &sum));
  EXPECT_TRUE(SameHandle(sum, six));

  // So does adding a dimension whose value is zero.
  TF_EXPECT_OK(ctx.Add(unknown, ctx.MakeDim(0ll), &sum));
  EXPECT_TRUE(SameHandle(sum, unknown));
  TF_EXPECT_OK(ctx.Add(six, ctx.MakeDim(0ll), &sum));
  EXPECT_TRUE(SameHandle(sum, six));

  // Known plus constant, up to the int64 limit.
  TF_EXPECT_OK(ctx.Add(six, 2, &sum));
  EXPECT_EQ("8", ctx.DebugString(sum));
  TF_EXPECT_OK(ctx.Add(six, kMaxDim - 6, &sum));
  EXPECT_EQ(kMaxDim, ctx.Value(sum));

  // Known plus dimension.
  TF_EXPECT_OK(ctx.Add(six, ctx.MakeDim(2), &sum));
  EXPECT_EQ("8", ctx.DebugString(sum));
  TF_EXPECT_OK(ctx.Add(six, ctx.MakeDim(kMaxDim - 6), &sum));
  EXPECT_EQ(kMaxDim, ctx.Value(sum));
  TF_EXPECT_OK(ctx.Add(six, ctx.UnknownDim(), &sum));
  EXPECT_EQ("?", ctx.DebugString(sum));
  TF_EXPECT_OK(ctx.Add(zero, six, &sum));
  EXPECT_TRUE(SameHandle(sum, six));

  // One past the int64 limit overflows.
  EXPECT_CONTAINS(ctx.Add(six, kMaxDim - 5, &sum).error_message(),
                  "Dimension size overflow from adding 6 and 9223372036854775802");
}
1416 
// Subtract() takes a constant or a dimension away from a dimension.
// Subtracting zero returns the input, unknowns propagate, and a negative
// result is an error.
TEST_F(ShapeInferenceTest, Subtract) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({6, -1, 0, 5})},
                       {}, {}, {});

  const ShapeHandle shape = ctx.input(0);
  const DimensionHandle six = ctx.Dim(shape, 0);
  const DimensionHandle unknown = ctx.Dim(shape, 1);
  const DimensionHandle zero = ctx.Dim(shape, 2);
  const DimensionHandle five = ctx.Dim(shape, 3);
  DimensionHandle diff;

  // A non-zero constant subtracted from an unknown gives a new unknown.
  TF_EXPECT_OK(ctx.Subtract(unknown, 1, &diff));
  EXPECT_EQ("?", ctx.DebugString(diff));
  EXPECT_FALSE(SameHandle(diff, unknown));

  // Subtracting a constant zero returns the input handle.
  TF_EXPECT_OK(ctx.Subtract(unknown, 0ll, &diff));
  EXPECT_TRUE(SameHandle(diff, unknown));
  TF_EXPECT_OK(ctx.Subtract(six, 0ll, &diff));
  EXPECT_TRUE(SameHandle(diff, six));

  // So does subtracting a dimension whose value is zero.
  TF_EXPECT_OK(ctx.Subtract(unknown, ctx.MakeDim(0ll), &diff));
  EXPECT_TRUE(SameHandle(diff, unknown));
  TF_EXPECT_OK(ctx.Subtract(six, ctx.MakeDim(0ll), &diff));
  EXPECT_TRUE(SameHandle(diff, six));

  // Known minus constant.
  TF_EXPECT_OK(ctx.Subtract(six, 2, &diff));
  EXPECT_EQ("4", ctx.DebugString(diff));
  TF_EXPECT_OK(ctx.Subtract(six, 6, &diff));
  EXPECT_EQ("0", ctx.DebugString(diff));

  // Known minus dimension.
  TF_EXPECT_OK(ctx.Subtract(six, ctx.MakeDim(2), &diff));
  EXPECT_EQ("4", ctx.DebugString(diff));
  TF_EXPECT_OK(ctx.Subtract(six, five, &diff));
  EXPECT_EQ("1", ctx.DebugString(diff));
  TF_EXPECT_OK(ctx.Subtract(six, ctx.UnknownDim(), &diff));
  EXPECT_EQ("?", ctx.DebugString(diff));
  TF_EXPECT_OK(ctx.Subtract(six, zero, &diff));
  EXPECT_TRUE(SameHandle(diff, six));

  // A negative result is rejected.
  EXPECT_CONTAINS(ctx.Subtract(five, six, &diff).error_message(),
                  "Negative dimension size caused by subtracting 6 from 5");
}
1465 
// Multiply() multiplies a dimension by a constant or a dimension. Zero
// dominates, one returns the other operand, and otherwise unknowns propagate.
TEST_F(ShapeInferenceTest, Multiply) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);
  auto d_1 = c.Dim(s, 3);

  // Multiplying unknown by a value other than 0 or 1 gives a new unknown,
  // checked the same way as in the Add, Subtract and Divide tests.
  DimensionHandle out;
  EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Multiplying 0 to anything gives 0.
  EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));

  // Multiplying 1 to anything gives the original.
  // (unknown -> unknown)
  EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  // (known -> known)
  EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));

  // Test multiplication.
  EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok());
  EXPECT_EQ("36", c.DebugString(out));

  // Test multiplication using dimension as second value.
  EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
}
1518 
// FullyDefined() holds only when the rank and every dimension are known.
TEST_F(ShapeInferenceTest, FullyDefined) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  // An unknown rank or any unknown dimension makes it false.
  EXPECT_FALSE(ctx.FullyDefined(ctx.UnknownShape()));
  EXPECT_FALSE(ctx.FullyDefined(ctx.Matrix(ctx.MakeDim(1), ctx.UnknownDim())));

  // Fully known shapes, including scalars, are true.
  EXPECT_TRUE(ctx.FullyDefined(ctx.Matrix(ctx.MakeDim(1), ctx.MakeDim(2))));
  EXPECT_TRUE(ctx.FullyDefined(ctx.Scalar()));
}
1532 
// Min() of two dimensions (or a dimension and a constant). A zero wins even
// against an unknown; otherwise any unknown makes the result unknown.
TEST_F(ShapeInferenceTest, Min) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1, 2, -1, 0})},
                       {}, {}, {});

  const ShapeHandle shape = ctx.input(0);
  const DimensionHandle one = ctx.Dim(shape, 0);
  const DimensionHandle two = ctx.Dim(shape, 1);
  const DimensionHandle unknown = ctx.Dim(shape, 2);
  const DimensionHandle zero = ctx.Dim(shape, 3);
  DimensionHandle smallest;

  // Zero against unknown gives zero, reusing the zero handle when one exists.
  TF_EXPECT_OK(ctx.Min(zero, unknown, &smallest));
  EXPECT_TRUE(SameHandle(zero, smallest));
  TF_EXPECT_OK(ctx.Min(unknown, zero, &smallest));
  EXPECT_TRUE(SameHandle(zero, smallest));
  TF_EXPECT_OK(ctx.Min(ctx.MakeDim(0ll), unknown, &smallest));
  EXPECT_EQ("0", ctx.DebugString(smallest));
  TF_EXPECT_OK(ctx.Min(unknown, 0ll, &smallest));
  EXPECT_EQ("0", ctx.DebugString(smallest));

  // Unknown against a non-zero value gives unknown.
  TF_EXPECT_OK(ctx.Min(unknown, unknown, &smallest));
  EXPECT_EQ("?", ctx.DebugString(smallest));
  TF_EXPECT_OK(ctx.Min(unknown, 1, &smallest));
  EXPECT_EQ("?", ctx.DebugString(smallest));
  TF_EXPECT_OK(ctx.Min(one, unknown, &smallest));
  EXPECT_EQ("?", ctx.DebugString(smallest));

  // Dimension against constant.
  TF_EXPECT_OK(ctx.Min(one, 1, &smallest));
  EXPECT_TRUE(SameHandle(one, smallest));
  TF_EXPECT_OK(ctx.Min(one, 3, &smallest));
  EXPECT_TRUE(SameHandle(one, smallest));
  TF_EXPECT_OK(ctx.Min(two, 1, &smallest));
  EXPECT_EQ("1", ctx.DebugString(smallest));

  // Dimension against dimension: the smaller handle is returned.
  TF_EXPECT_OK(ctx.Min(one, one, &smallest));
  EXPECT_TRUE(SameHandle(one, smallest));
  TF_EXPECT_OK(ctx.Min(one, two, &smallest));
  EXPECT_TRUE(SameHandle(one, smallest));
  TF_EXPECT_OK(ctx.Min(two, one, &smallest));
  EXPECT_TRUE(SameHandle(one, smallest));
  TF_EXPECT_OK(ctx.Min(two, two, &smallest));
  EXPECT_TRUE(SameHandle(two, smallest));
}
1581 
// Checks InferenceContext::Max on dimensions 1, 2 and unknown.
//
// Any unknown operand yields an unknown result. When the maximum is one of the
// input DimensionHandles, that very handle is returned.
TEST_F(ShapeInferenceTest, Max) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1, 2, -1})},
                       {}, {}, {});

  auto input_shape = ctx.input(0);
  auto dim1 = ctx.Dim(input_shape, 0);
  auto dim2 = ctx.Dim(input_shape, 1);
  auto dim_unknown = ctx.Dim(input_shape, 2);
  DimensionHandle result;

  // Unknown on either side makes the result unknown.
  EXPECT_TRUE(ctx.Max(dim_unknown, dim_unknown, &result).ok());
  EXPECT_EQ("?", ctx.DebugString(result));
  EXPECT_TRUE(ctx.Max(dim_unknown, 1, &result).ok());
  EXPECT_EQ("?", ctx.DebugString(result));
  EXPECT_TRUE(ctx.Max(dim1, dim_unknown, &result).ok());
  EXPECT_EQ("?", ctx.DebugString(result));

  // Handle vs. constant: the handle is reused when it is the larger value.
  EXPECT_TRUE(ctx.Max(dim1, 1, &result).ok());
  EXPECT_TRUE(SameHandle(dim1, result));
  EXPECT_TRUE(ctx.Max(dim2, 1, &result).ok());
  EXPECT_TRUE(SameHandle(dim2, result));
  EXPECT_TRUE(ctx.Max(dim2, 3, &result).ok());
  EXPECT_EQ("3", ctx.DebugString(result));

  // Handle vs. handle: the larger handle comes back, in either order.
  EXPECT_TRUE(ctx.Max(dim1, dim1, &result).ok());
  EXPECT_TRUE(SameHandle(dim1, result));
  EXPECT_TRUE(ctx.Max(dim1, dim2, &result).ok());
  EXPECT_TRUE(SameHandle(dim2, result));
  EXPECT_TRUE(ctx.Max(dim2, dim1, &result).ok());
  EXPECT_TRUE(SameHandle(dim2, result));
  EXPECT_TRUE(ctx.Max(dim2, dim2, &result).ok());
  EXPECT_TRUE(SameHandle(dim2, result));
}
1619 
TestMergeHandles(bool input_not_output)1620 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
1621   NodeDef def;
1622   InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1623                      {});
1624   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1625     ShapeHandle s;
1626     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1627     return s;
1628   };
1629   auto get_shapes_and_types_from_context = [&](int idx) {
1630     if (input_not_output) {
1631       return c.input_handle_shapes_and_types(idx);
1632     } else {
1633       return c.output_handle_shapes_and_types(idx);
1634     }
1635   };
1636   auto merge_shapes_and_types_to_context =
1637       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1638         if (input_not_output) {
1639           return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
1640         } else {
1641           return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
1642         }
1643       };
1644 
1645   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1646   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1647 
1648   // First merge will take the input completely.
1649   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1650                               {c.UnknownShape(), DT_INVALID},
1651                               {make_shape({4, 3, 2, 1}), DT_INT32}};
1652   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1653   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1654   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1655   ASSERT_EQ(3, v.size());
1656   for (int i = 0; i < v.size(); ++i) {
1657     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1658     EXPECT_EQ(t[i].dtype, v[i].dtype);
1659   }
1660 
1661   // Merge that fails because wrong number of values passed.
1662   // Fails, and no changes made.
1663   ASSERT_FALSE(merge_shapes_and_types_to_context(
1664       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1665   v = *get_shapes_and_types_from_context(0);
1666   ASSERT_EQ(3, v.size());
1667   for (int i = 0; i < v.size(); ++i) {
1668     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1669     EXPECT_EQ(t[i].dtype, v[i].dtype);
1670   }
1671 
1672   // Only difference is in a mismatched shape. That is ignored,
1673   // and there are no other changes, so nothing is done.
1674   //
1675   // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
1676   // return an error (separate error from 'refined' output)?
1677   auto t2 = t;
1678   t2[2].shape = make_shape({4, 3, 4, 1});
1679   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1680   v = *get_shapes_and_types_from_context(0);
1681   ASSERT_EQ(3, v.size());
1682   for (int i = 0; i < v.size(); ++i) {
1683     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1684     EXPECT_EQ(t[i].dtype, v[i].dtype);
1685   }
1686 
1687   // Only difference is in a mismatched dtype, but that cannot be
1688   // updated unless original dtype is DT_INVALID.
1689   t2 = t;
1690   t2[2].dtype = DT_FLOAT;
1691   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1692   v = *get_shapes_and_types_from_context(0);
1693   ASSERT_EQ(3, v.size());
1694   for (int i = 0; i < v.size(); ++i) {
1695     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1696     EXPECT_EQ(t[i].dtype, v[i].dtype);
1697   }
1698 
1699   // Difference is mergeable (new shape).
1700   t[1].shape = make_shape({1, 10});
1701   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1702   v = *get_shapes_and_types_from_context(0);
1703   ASSERT_EQ(3, v.size());
1704   for (int i = 0; i < v.size(); ++i) {
1705     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1706     EXPECT_EQ(t[i].dtype, v[i].dtype);
1707   }
1708 
1709   // Difference is mergeable (new type).
1710   t[1].dtype = DT_DOUBLE;
1711   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1712   v = *get_shapes_and_types_from_context(0);
1713   ASSERT_EQ(3, v.size());
1714   for (int i = 0; i < v.size(); ++i) {
1715     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1716     EXPECT_EQ(t[i].dtype, v[i].dtype);
1717   }
1718 
1719   // No difference.
1720   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
1721 }
1722 
// Runs the shared merge scenario through the input-handle accessors.
TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) {
  TestMergeHandles(/*input_not_output=*/true);
}
1726 
// Runs the shared merge scenario through the output-handle accessors.
TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
  TestMergeHandles(/*input_not_output=*/false);
}
1730 
TestRelaxHandles(bool input_not_output)1731 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
1732   NodeDef def;
1733   InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1734                      {});
1735   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1736     ShapeHandle s;
1737     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1738     return s;
1739   };
1740   auto get_shapes_and_types_from_context = [&](int idx) {
1741     if (input_not_output) {
1742       return c.input_handle_shapes_and_types(idx);
1743     } else {
1744       return c.output_handle_shapes_and_types(idx);
1745     }
1746   };
1747   auto relax_shapes_and_types_to_context =
1748       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1749         if (input_not_output) {
1750           return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types);
1751         } else {
1752           return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types);
1753         }
1754       };
1755 
1756   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1757   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1758 
1759   // First relax will take the input completely.
1760   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1761                               {c.UnknownShape(), DT_INVALID},
1762                               {make_shape({4, 3, 2, 1}), DT_INT32}};
1763   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1764   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1765   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1766   ASSERT_EQ(3, v.size());
1767   for (int i = 0; i < v.size(); ++i) {
1768     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1769     EXPECT_EQ(t[i].dtype, v[i].dtype);
1770   }
1771 
1772   // Relax that fails because wrong number of values passed.
1773   // Fails, and no changes made.
1774   ASSERT_FALSE(relax_shapes_and_types_to_context(
1775       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1776   v = *get_shapes_and_types_from_context(0);
1777   ASSERT_EQ(3, v.size());
1778   for (int i = 0; i < v.size(); ++i) {
1779     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1780     EXPECT_EQ(t[i].dtype, v[i].dtype);
1781   }
1782 
1783   // Only difference is in a mismatched shape. This should replace
1784   // the mismatched dimension with an UnknownDim.
1785   auto t2 = t;
1786   t2[2].shape = make_shape({4, 3, 4, 1});
1787   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2));
1788   v = *get_shapes_and_types_from_context(0);
1789   EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape));
1790   for (int i = 0; i < v.size(); ++i) {
1791     EXPECT_EQ(t[i].dtype, v[i].dtype);
1792   }
1793 
1794   // Only difference is in a mismatched dtype, but that cannot be
1795   // updated unless original dtype is DT_INVALID.
1796   t2 = t;
1797   t2[2].dtype = DT_FLOAT;
1798   ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2));
1799   v = *get_shapes_and_types_from_context(0);
1800   ASSERT_EQ(3, v.size());
1801   for (int i = 0; i < v.size(); ++i) {
1802     EXPECT_EQ(t[i].dtype, v[i].dtype);
1803   }
1804 
1805   // Difference is a new shape, which will result in a new UnknownShape.
1806   t[1].shape = make_shape({1, 10});
1807   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1808   v = *get_shapes_and_types_from_context(0);
1809   ASSERT_EQ(3, v.size());
1810   EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape));
1811   EXPECT_EQ("?", c.DebugString(v[1].shape));
1812   for (int i = 0; i < v.size(); ++i) {
1813     EXPECT_EQ(t[i].dtype, v[i].dtype);
1814   }
1815 
1816   // Difference is relaxable (new type).
1817   t[1].dtype = DT_DOUBLE;
1818   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1819   v = *get_shapes_and_types_from_context(0);
1820   EXPECT_EQ(t[1].dtype, v[1].dtype);
1821 }
1822 
// Runs the shared relax scenario through the input-handle accessors.
TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) {
  TestRelaxHandles(/*input_not_output=*/true);
}
1826 
// Runs the shared relax scenario through the output-handle accessors.
TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) {
  TestRelaxHandles(/*input_not_output=*/false);
}
1830 
1831 }  // namespace shape_inference
1832 }  // namespace tensorflow
1833