• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/functional_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/graph_def_util.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
31 #include "tensorflow/core/grappler/inputs/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/test.h"
37 #ifdef INTEL_MKL
38 #include "tensorflow/core/graph/mkl_graph_util.h"
39 #endif
40 
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44 
45 using shape_inference::InferenceContext;
46 using shape_inference::ShapeAndType;
47 using shape_inference::ShapeHandle;
48 
// Directory, relative to the TensorFlow source root, that holds the .pbtxt
// graphs read by the file-based tests (joined with
// testing::TensorFlowSrcRoot() and a file name).
constexpr char kTestDataPath[] =
    "core/grappler/costs/graph_properties_testdata";
50 
// Registers a float -> float test op. Note that the builder chain below never
// calls SetShapeFn, so no shape inference function is attached to this op.
// It is not referenced in the tests visible here; presumably it is exercised
// by later tests that check inference on ops lacking a shape function.
REGISTER_OP("TestOpWithNoInferenceFn")
    .Input("x: float")
    .Output("y: float")
    .Doc(R"doc(
Test op with no Inference Function registered.
x: input
y: output
)doc");
59 
60 class GraphPropertiesTest : public ::testing::Test {
61  public:
SetUp()62   void SetUp() override {
63     // Provision a single machine with 3 cpu cores
64     cluster_.reset(new SingleMachine(5 * 60, 3, 0));
65     TF_ASSERT_OK(cluster_->Provision());
66 
67     // This function is simply
68     // out = Fill(shape, value), but
69     // Fill requires values in the shape input, not just shape of it, to infer
70     // output shape.
71     auto f = FunctionDefHelper::Create(
72         // Name
73         "MyFillFunc",
74         // Inputs
75         {"shape: int32", "value: float"},
76         // Outputs
77         {"out: float"},
78         // Attrs
79         {},
80         // Nodes
81         {
82             {{"a"},
83              "Fill",
84              {"shape", "value"},
85              {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
86         },
87         // Returns
88         {{"out", "a:output:0"}});
89     function_lib_.add_function()->Swap(&f);
90   }
91 
TearDown()92   void TearDown() override {
93     TF_ASSERT_OK(cluster_->Shutdown());
94     cluster_.reset();
95   }
96 
97  protected:
98   // Returns a string form of <p>, suitable for comparing type and shape.
99   // Example output for 4-d float tensor: "float: [10,2,30,4]"
PropToString(const OpInfo::TensorProperties & p)100   string PropToString(const OpInfo::TensorProperties& p) {
101     string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
102     if (p.shape().unknown_rank()) {
103       strings::StrAppend(&s, "?");
104     } else {
105       strings::StrAppend(&s, "[");
106       for (int i = 0; i < p.shape().dim_size(); ++i) {
107         strings::StrAppend(&s, i == 0 ? "" : ",",
108                            std::max<int64>(p.shape().dim(i).size(), -1));
109       }
110       strings::StrAppend(&s, "]");
111     }
112     return s;
113   }
114 
115   // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
116   // ones.
ExpectTensorValues(const std::vector<int64> & expected,const TensorProto & tensor_proto_to_compare)117   void ExpectTensorValues(const std::vector<int64>& expected,
118                           const TensorProto& tensor_proto_to_compare) {
119     Tensor tensor;
120     ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
121     EXPECT_EQ(expected.size(), tensor.NumElements());
122     // We're interested in only integer tensors as only shapes are exported as
123     // graph properties values.
124     ASSERT_TRUE(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
125     if (tensor.dtype() == DT_INT32) {
126       for (int i = 0; i < tensor.NumElements(); i++) {
127         EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
128       }
129     } else {
130       for (int i = 0; i < tensor.NumElements(); i++) {
131         EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
132       }
133     }
134   }
135 
136   // Compare values of float (DT_FLOAT) tensor against expected
137   // ones.
ExpectFloatTensorValues(const std::vector<float> & expected,const TensorProto & tensor_proto_to_compare)138   void ExpectFloatTensorValues(const std::vector<float>& expected,
139                                const TensorProto& tensor_proto_to_compare) {
140     Tensor tensor;
141     ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
142     EXPECT_EQ(expected.size(), tensor.NumElements());
143     ASSERT_EQ(tensor.dtype(), DT_FLOAT);
144     for (int i = 0; i < tensor.NumElements(); i++) {
145       EXPECT_EQ(expected[i], tensor.flat<float>()(i));
146     }
147   }
148 
149   std::unique_ptr<SingleMachine> cluster_;
150   FunctionDefLibrary function_lib_;
151 };
152 
// Static inference on the trivial test graph: RandomStandardNormal and AddN
// nodes must report a float [10,1] tensor, and AddN must preserve its input
// type and shape on its output.
TEST_F(GraphPropertiesTest, StaticProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  // ASSERT rather than CHECK: failing to build the item should fail this
  // test, not abort the whole test binary.
  ASSERT_TRUE(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The node has one input (the shape of the tensor to generate).
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output. The size and rank checks are fatal so
      // that props[0] and dim(0)/dim(1) below are never read out of range.
      const auto props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, props.size());
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype());
      EXPECT_FALSE(prop.shape().unknown_rank());
      ASSERT_EQ(2, prop.shape().dim_size());
      EXPECT_EQ(10, prop.shape().dim(0).size());
      EXPECT_EQ(1, prop.shape().dim(1).size());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      EXPECT_EQ(DT_FLOAT, in_prop.dtype());
      EXPECT_FALSE(in_prop.shape().unknown_rank());
      ASSERT_EQ(2, in_prop.shape().dim_size());
      EXPECT_EQ(10, in_prop.shape().dim(0).size());
      EXPECT_EQ(1, in_prop.shape().dim(1).size());
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      EXPECT_EQ(in_prop.dtype(), out_props[0].dtype());
      EXPECT_EQ(in_prop.shape().DebugString(),
                out_props[0].shape().DebugString());
    }
  }
}
193 
// ClearOutputProperties / ClearInputProperties must drop the properties that
// static inference computed for a node.
TEST_F(GraphPropertiesTest, ClearProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // Make sure there is something to clear first; otherwise the emptiness
      // check below would pass trivially.
      EXPECT_EQ(1, properties.GetOutputProperties(node.name()).size());
      properties.ClearOutputProperties(node.name());
      const auto cleared_props = properties.GetOutputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      EXPECT_EQ(1, in_props.size());
      properties.ClearInputProperties(node.name());
      const auto cleared_props = properties.GetInputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    }
  }
}
220 
TEST_F(GraphPropertiesTest,DynamicProperties)221 TEST_F(GraphPropertiesTest, DynamicProperties) {
222   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
223                                           cluster_->GetDeviceNames());
224   GrapplerItem item;
225   CHECK(fake_input.NextItem(&item));
226 
227   GraphProperties properties(item);
228   TF_ASSERT_OK(cluster_->Initialize(item));
229   Status s = properties.InferDynamically(cluster_.get());
230   TF_ASSERT_OK(s);
231 
232   for (const auto& node : item.graph.node()) {
233     if (node.op() == "RandomStandardNormal") {
234       // The random node is missing from the cost graph (why ?)
235       EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
236     } else if (node.op() == "AddN") {
237       // Since the random node is missing, we can't infer the input properties
238       // of the first AddN node. The other AddN nodes have the expected
239       // properties.
240       if (node.name() == "AddN") {
241         const auto props = properties.GetInputProperties(node.name());
242         EXPECT_EQ(1, props.size());
243         const OpInfo::TensorProperties& prop = props[0];
244         EXPECT_EQ(DT_INVALID, prop.dtype());
245         EXPECT_TRUE(prop.shape().unknown_rank());
246       } else {
247         const auto props = properties.GetInputProperties(node.name());
248         EXPECT_EQ(1, props.size());
249         const OpInfo::TensorProperties& prop = props[0];
250         EXPECT_EQ(DT_FLOAT, prop.dtype());
251         EXPECT_FALSE(prop.shape().unknown_rank());
252         EXPECT_EQ(2, prop.shape().dim_size());
253         EXPECT_EQ(10, prop.shape().dim(0).size());
254         EXPECT_EQ(1, prop.shape().dim(1).size());
255         const auto out_props = properties.GetOutputProperties(node.name());
256 #ifdef INTEL_MKL
257         if (!NativeFormatEnabled()) {
258           // Intel MKL AddN OP would have two output.
259           // One is the real output, another one for MKL metadata
260           EXPECT_EQ(2, out_props.size());
261         } else {
262           EXPECT_EQ(1, out_props.size());
263         }
264 #else
265         EXPECT_EQ(1, out_props.size());
266 #endif  // INTEL_MKL
267         string prop_str;
268         ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
269         string out_prop_str;
270         ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
271                                                           &out_prop_str);
272         EXPECT_EQ(prop_str, out_prop_str);
273       }
274     }
275   }
276 }
277 
// A ref-typed "Variable" with a [3,7] shape attr must report a float_ref
// [3,7] output under both static and dynamic inference.
TEST_F(GraphPropertiesTest, Variables) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "Variable")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  item.fetch.push_back("Var");

  // Initialize the variable so the graph can actually run for dynamic
  // inference.
  Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
  test::FillIota<float>(&initial_val, 0);
  TF_ASSERT_OK(NodeDefBuilder("InitialVal", "Const")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", initial_val)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("InitVar", "Assign")
                   .Input("Var", 0, DT_FLOAT_REF)
                   .Input("InitialVal", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  item.init_ops.push_back("InitVar");

  {
    GraphProperties static_properties(item);
    TF_ASSERT_OK(static_properties.InferStatically(false));

    // Fatal size/rank checks keep props[0] and dim() accesses in range.
    const auto props = static_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    ASSERT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
  {
    TF_ASSERT_OK(cluster_->Initialize(item));
    GraphProperties dynamic_properties(item);
    TF_ASSERT_OK(dynamic_properties.InferDynamically(cluster_.get()));

    const auto props = dynamic_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    ASSERT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
}
326 
// A resource handle passed through an Enter node must still let a downstream
// ReadVariableOp recover the variable's float [3,7] shape.
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("Enter", "Enter")
                   .Attr("T", DT_RESOURCE)
                   .Attr("frame_name", "while_context")
                   .Attr("is_constant", true)
                   .Attr("parallel_iterations", 10)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Enter", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  // Fatal size/rank checks keep props[0] and dim() accesses in range.
  const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
356 
// ReadVariableOp directly on a VarHandleOp must recover the handle's float
// [3,7] shape.
TEST_F(GraphPropertiesTest, VarHandles) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));

  TF_ASSERT_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Fatal size/rank checks keep props[0] and dim() accesses in range.
  const auto props = properties.GetOutputProperties("VarRead");
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
381 
// A DT_RESOURCE value flowing through while-loop control nodes must stay a
// scalar resource, and the ReadVariableOp inside the loop body must recover
// the int32 scalar value type.
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
  // Test graph is first generated in python using:
  /*
    i0 = tf.constant(0)
    v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
    def cond(i, x):
      return i < 3
    def body(i, x):
      return i + 1, x + x
    v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
  */
  // and then modified by hand such that the ReadVariableOp is inside the loop
  // body instead of outside the while loop (which is the case when constructed
  // using the python API), such that we have the following pattern: VarHandleOp
  // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
  // DT_RESOURCE is passed all the way until ReadVariableOp.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop_var_handle_op.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> resource_nodes{
      "loop_var",       "while/Enter",         "while/Merge", "while/Switch",
      "while/Identity", "while/NextIteration", "while/Exit"};
  for (const string& node : resource_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // At least one output (Merge has 2); fatal so props[0] is in range.
    ASSERT_FALSE(props.empty()) << node;
    EXPECT_EQ("resource: []", PropToString(props[0])) << node;
  }

  // After ReadVariableOp, the shape should be recovered.
  const auto props = properties.GetOutputProperties("while/ReadVariableOp");
  ASSERT_EQ(1, props.size());
  EXPECT_EQ("int32: []", PropToString(props[0]));
}
419 
// A FIFO queue with no "shapes" attr and no enqueue ops: the dequeued tensor's
// dtype is known, but its shape (and even rank) is not.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT});
  // Constructing the op adds it to the graph; the handle itself is not needed.
  ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue, {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_outputs = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_outputs.size());
  EXPECT_EQ("float: ?", PropToString(dequeue_outputs.front()));
}
436 
// A FIFO queue whose "shapes" attr fully specifies [3,7,1] and that is never
// enqueued into: the dequeued shape must come from the attr alone.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // The Shapes() attr is built inline, inside the FIFOQueue call expression.
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                              ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
  ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue, {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_outputs = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_outputs.size());
  EXPECT_EQ("float: [3,7,1]", PropToString(dequeue_outputs.front()));
}
454 
// A FIFO queue whose "shapes" attr is only partially known ([3,7,-1]) and that
// is never enqueued into: the unknown trailing dim must survive as -1.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // The Shapes() attr is built inline, inside the FIFOQueue call expression.
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                              ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
  ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue, {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_outputs = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_outputs.size());
  EXPECT_EQ("float: [3,7,-1]", PropToString(dequeue_outputs.front()));
}
472 
TEST_F(GraphPropertiesTest,Queues)473 TEST_F(GraphPropertiesTest, Queues) {
474   // Create a graph with known input shapes, and propagate the shapes through a
475   // couple of queues.
476   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
477 
478   auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
479   Output rnd =
480       ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
481   Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
482   auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
483   auto dequeue1 =
484       ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
485 
486   auto q2 =
487       ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
488   Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
489   auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
490   auto dequeue2 =
491       ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
492 
493   auto q4 =
494       ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
495   auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
496   auto enqueue4_2 =
497       ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
498   auto dequeue4 =
499       ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
500 
501   // Create a queue that takes in three tensors.
502   auto q5 = ops::RandomShuffleQueue(
503       root.WithOpName("Queue5"),
504       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
505   Output rnd2 =
506       ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
507   Output rnd3 =
508       ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
509   auto enqueue5 =
510       ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
511   auto dequeue5 = ops::QueueDequeue(
512       root.WithOpName("Dequeue5"), q5,
513       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
514 
515   GrapplerItem item;
516   TF_ASSERT_OK(root.ToGraphDef(&item.graph));
517 
518   GraphProperties properties(item);
519   TF_ASSERT_OK(properties.InferStatically(false));
520 
521   const auto props1 = properties.GetOutputProperties("Dequeue1");
522   ASSERT_EQ(1, props1.size());
523   EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
524 
525   const auto props2 = properties.GetOutputProperties("Dequeue2");
526   ASSERT_EQ(1, props2.size());
527   EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
528 
529   // The dequeue3 op shape is unknown. The square2 op shape is known. Verify
530   // that we merge the 2 properly to determine the shape of the data coming out
531   // of the queue.
532   const auto props4 = properties.GetOutputProperties("Dequeue4");
533   ASSERT_EQ(1, props4.size());
534   EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
535 
536   // The dequeue5 op shape is known.
537   const auto props5 = properties.GetOutputProperties("Dequeue5");
538   ASSERT_EQ(3, props5.size());
539   EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
540   EXPECT_EQ("double: [10]", PropToString(props5[1]));
541   EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
542 }
543 
// A Merge fed by the two branches of a tf.cond (outside of any loop) must
// combine the branch shapes [2,1,1] and [1,2,1] into [-1,-1,1].
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      x = tf.constant(2)
      y = tf.constant(5)
      z = tf.ones([1,1,1])
      def f1(): return tf.concat([z, z], axis=0)
      def f2(): return tf.concat([z, z], axis=1)
      r = tf.cond(tf.less(x, y), f1, f2)
      tf.concat([r, r], axis=2)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "merge_without_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // (node name, expected first-output property) pairs.
  const std::vector<std::pair<string, string>> expected_outputs{
      {"cond/Merge", "float: [-1,-1,1]"},
      {"cond/concat", "float: [2,1,1]"},
      {"cond/concat_1", "float: [1,2,1]"}};
  for (const auto& node_and_expected : expected_outputs) {
    const string& node = node_and_expected.first;
    const auto props = properties.GetOutputProperties(node);
    // Fatal so props[0] is never read from an empty vector.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ(node_and_expected.second, PropToString(prop)) << node;
  }

  // The "Less" node should be fed by 2 int32 scalar constant values.
  const auto props = properties.GetInputProperties("Less");
  EXPECT_EQ(2, props.size());
  for (const auto& prop : props) {
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_TRUE(prop.has_value());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
}
585 
// A while loop whose body doubles the batch dim of a [?,2] loop variable: the
// loop-carried nodes must report [-1,2], and the loop output's batch dim must
// be a different unknown from the input's.
TEST_F(GraphPropertiesTest, WhileLoop) {
  // Test graph produced in python using:
  /*
     with tf.Graph().as_default():
       i0 = tf.constant(0)
       m0 = tf.placeholder([-1, 2])
       c = lambda i, m: i < 10
       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
       r = tf.while_loop(
              c, b, loop_vars=[i0, m0],
              shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
       with open('/tmp/graph.pbtxt', 'w') as f:
         f.write(str(tf.get_default_graph().as_graph_def()))
  */
  // NOTE(review): the checks below read a node named "ones", which the snippet
  // above does not create; the snippet may be out of date relative to
  // while_loop.pbtxt.

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};
  for (const string& node : nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal so props[0] is never read from an empty vector.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,2]", PropToString(prop)) << node;
  }

  // The loop outputs batch dim should be different from the input batch dim
  // since we concatenated along the batch dim. Explicit checks replace
  // .at(0)/dim(0), which would throw or index out of range instead of failing
  // the test.
  const auto in_props = properties.GetOutputProperties("ones");
  const auto out_props = properties.GetOutputProperties("while/Exit_1");
  ASSERT_FALSE(in_props.empty());
  ASSERT_FALSE(out_props.empty());
  const auto& shape_in = in_props[0].shape();
  const auto& shape_out = out_props[0].shape();
  ASSERT_GT(shape_in.dim_size(), 0);
  ASSERT_GT(shape_out.dim_size(), 0);
  // Both batch dims must be <= -2, presumably distinct symbolic unknowns (as
  // opposed to the generic -1), which is what lets EXPECT_NE tell them apart.
  EXPECT_GE(-2, shape_in.dim(0).size());
  EXPECT_GE(-2, shape_out.dim(0).size());
  EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
}
625 
// Nested while loops with shape invariants: outer loop-carried nodes must
// report [-1,1,1] and inner ones [-1,1,-1].
TEST_F(GraphPropertiesTest, NestedLoop) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=2)

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                              tf.TensorShape([None, 1, None])])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        j, y = inner(0, x)
        return i+1, tf.concat([x, x], axis=0)

      r = tf.while_loop(outer_cond, outer_body,
                        loop_vars=[i0, tf.ones([1, 1, 1])],
                        shape_invariants=[i0.get_shape(),
                                          tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "nested_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal so props[0] is never read from an empty vector.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,1]", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)) << node;
  }
}
684 
// Nested while loops connected through a FIFO queue: outer loop-carried nodes
// must report [1,1,-1], and inner ones (fed from the dequeue) [-1,1,-1].
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q = tf.FIFOQueue(1, "float")

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=0)

        return tf.while_loop(inner_cond, inner_body,
                             loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape(None)])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        q.enqueue(x)
        y = tf.concat([x, x], axis=2)
        inner(0, q.dequeue())
        return i+1, y

      i, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[i0, tf.ones([1, 1, 1])],
                           shape_invariants=[i0.get_shape(),
                                             tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_queues.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal so props[0] is never read from an empty vector.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [1,1,-1]", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)) << node;
  }
}
747 
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
        v = tf.get_variable(initializer=i0, name='loop_var')

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j + 1, y + y

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        y = x + x
        inner(0, v)
        return i + 1, y

      v, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[v, tf.constant(1)])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_resource_vars.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // The second loop variable of both the outer and the nested loop is
  // expected to be an int32 scalar.
  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto& props = properties.GetOutputProperties(node);
    // Fail cleanly instead of indexing an empty vector if no output
    // properties were recorded for this node.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype()) << node;
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto& props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype()) << node;
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
  }
}
805 
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q0 = tf.FIFOQueue(1, "float")
      q0.enqueue(tf.ones([2, 2]))
      q1 = tf.FIFOQueue(1, "float")

      def c(i, m):
        return i < 10

      def b(i, m):
        return i+1, tf.concat([m, m], axis=0)

      i, m = tf.while_loop(
          c, b, loop_vars=[i0,  q0.dequeue()],
          shape_invariants=[i0.get_shape(), tf.TensorShape(None)])

      q1.enqueue(m)
      v = q1.dequeue();
      tf.concat([v, v], axis=1)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "queues_and_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Nodes carrying the loop variable `m`, which is concatenated along axis 0
  // on every iteration.
  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};

  for (const string& node : nodes) {
    const auto& props = properties.GetOutputProperties(node);
    // Fail cleanly instead of indexing an empty vector if no output
    // properties were recorded for this node.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,2]", PropToString(prop)) << node;
  }

  // The shape is expected to flow through the second queue into the final
  // concat along axis 1.
  const auto& concat_props = properties.GetOutputProperties("concat");
  ASSERT_FALSE(concat_props.empty());
  const OpInfo::TensorProperties& concat_prop = concat_props[0];
  EXPECT_EQ(DT_FLOAT, concat_prop.dtype());
  EXPECT_EQ("float: [-1,4]", PropToString(concat_prop));
}
854 
TEST_F(GraphPropertiesTest,InferRestoreOpShape)855 TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
856   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
857   Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
858                              DataType::DT_FLOAT);
859   Output filename =
860       ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
861   Output tensor_name =
862       ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
863   Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
864                                 DataType::DT_FLOAT);
865   Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);
866 
867   Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
868                                       string("256 256 0,128:-"), TensorShape());
869   Output restore_slice =
870       ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
871                         shape_and_slice, DataType::DT_FLOAT);
872   Output init_restore_slice =
873       ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);
874 
875   Output restore_v2 =
876       ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
877                         shape_and_slice, DataType::DT_FLOAT);
878   Output init_restore_v2 =
879       ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);
880 
881   GrapplerItem item;
882   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
883   item.fetch.push_back("init_restore");
884 
885   GraphProperties properties(item);
886   TF_ASSERT_OK(properties.InferStatically(false));
887 
888   const auto restore_props = properties.GetOutputProperties("restore");
889   const OpInfo::TensorProperties& restore_prop = restore_props[0];
890   EXPECT_EQ(DT_FLOAT, restore_prop.dtype());
891   EXPECT_EQ("float: [128,256]", PropToString(restore_prop));
892 
893   const auto restore_slice_props =
894       properties.GetOutputProperties("restore_slice");
895   const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0];
896   EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype());
897   EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop));
898 
899   const auto restorev2_props = properties.GetOutputProperties("restore_v2");
900   const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0];
901   EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype());
902   EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop));
903 
904   // Check input shapes of assign op are propagated correctly.
905   const auto input_props = properties.GetInputProperties("init_restore");
906   ASSERT_EQ(2, input_props.size());
907   const OpInfo::TensorProperties& input_prop = input_props[1];
908   EXPECT_EQ(DT_FLOAT, input_prop.dtype());
909   EXPECT_EQ("float: [128,256]", PropToString(input_prop));
910 }
911 
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
  // A single Restore op feeds two Assigns: one into `var` (unknown shape) and
  // one into `var2` ([128,256]). The test expects the restore output shape to
  // be the known [128,256].
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
                             DataType::DT_FLOAT);
  Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}),
                              DataType::DT_FLOAT);
  Output filename =
      ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
  Output tensor_name =
      ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
  Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
                                DataType::DT_FLOAT);
  Output init = ops::Assign(s.WithOpName("init"), var, restore);
  Output init2 = ops::Assign(s.WithOpName("init2"), var2, restore);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("init");
  item.fetch.push_back("init2");

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& props = properties.GetOutputProperties("restore");
  // Fail cleanly instead of indexing an empty vector.
  ASSERT_FALSE(props.empty());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(prop));
}
940 
TEST_F(GraphPropertiesTest,TensorAsShapesPropagation)941 TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
942   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
943   Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
944   Output a1 = ops::Identity(s.WithOpName("a1"), a);
945   Output b = ops::Const(s.WithOpName("b"), 99, {});
946   Output b1 = ops::Identity(s.WithOpName("b1"), b);
947   Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
948   Output c1 = ops::Identity(s.WithOpName("c1"), c);
949 
950   GrapplerItem item;
951   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
952   GraphProperties properties(item);
953   TF_ASSERT_OK(properties.InferStatically(false));
954 
955   // Check output shapes.
956   EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
957   EXPECT_EQ("int32: [2]",
958             PropToString(properties.GetOutputProperties("a1")[0]));
959   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
960   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
961   EXPECT_EQ("int32: [4,4,4]",
962             PropToString(properties.GetOutputProperties("c")[0]));
963   EXPECT_EQ("int32: [4,4,4]",
964             PropToString(properties.GetOutputProperties("c1")[0]));
965 
966   // Check has_value.
967   EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
968   EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
969   EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
970   EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
971   EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
972   EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
973   EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
974   EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
975   // Note that we propagate tensor value of only 1D vector and scalar.
976   EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());
977 
978   // Check values.
979   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
980   ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
981   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
982   ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
983   ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
984   ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
985   std::vector<int64> c_values;
986   for (int i = 0; i < 4 * 4 * 4; i++) {
987     c_values.push_back(1);
988   }
989   ExpectTensorValues({c_values},
990                      properties.GetOutputProperties("c")[0].value());
991   ExpectTensorValues({c_values},
992                      properties.GetInputProperties("c1")[0].value());
993   ExpectTensorValues({c_values},
994                      properties.GetOutputProperties("c1")[0].value());
995 }
996 
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 5, {2});
  Output b = ops::Identity(s.WithOpName("b"), a);
  Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only b's shape but also the value of b to figure out the
  // output shape; hence, the Identity op (b) should pass a's value as
  // output_tensors_as_shape.
  Output d = ops::Fill(s.WithOpName("fill"), b, c);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  // Fail cleanly instead of indexing an empty vector.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [5,5]", PropToString(out_props[0]));
}
1015 
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
  // When using aggressive_shape_inference, we run EvaluateNode() for
  // allowlisted ops and small input / output tensors. For instance, Fill op is
  // evaluated and produces output tensor value if output tensor size is small
  // (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
  // This is to avoid wasting time and memory for producing huge tensors (e.g.,
  // initializing a large table using Fill).
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 4, {2});  // Shape [4, 4].
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is small; expect output values of Fill op.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    // Fail cleanly instead of indexing an empty vector.
    ASSERT_FALSE(out_props.empty());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());
  }
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    // Shape [1000, 1000, 1000, 1000].
    Output a = ops::Const(s.WithOpName("a"), 1000, {4});
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is huge; in that case we skip value inference.
    // Otherwise, it'd be too much overhead.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    ASSERT_FALSE(out_props.empty());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }
}
1063 
TEST_F(GraphPropertiesTest, PackWithConstInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {});
  Output b = ops::Const(s.WithOpName("b"), 2, {});
  Output c = ops::Const(s.WithOpName("c"), 3, {});
  Output d = ops::Const(s.WithOpName("d"), 4, {});
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  // Fail cleanly instead of indexing an empty vector.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
}
1086 
TEST_F(GraphPropertiesTest, RankOp) {
  // The rank of a [4,4,4] Const is the scalar 3. The test expects that value
  // to be inferred statically and forwarded through an Identity.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("Const"), 1, {4, 4, 4});
  Output r = ops::Rank(s.WithOpName("Rank"), c);
  Output i = ops::Identity(s.WithOpName("Identity"), r);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  for (const char* node : {"Rank", "Identity"}) {
    const auto& props = properties.GetOutputProperties(node);
    // Fail cleanly instead of indexing an empty vector.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
    EXPECT_TRUE(prop.has_value()) << node;
    ExpectTensorValues({3}, prop.value());
  }
}
1108 
TEST_F(GraphPropertiesTest, SizeOp) {
  // The size of a [1,2,3,4] Const is the scalar 24. The test expects that
  // value to be inferred statically and forwarded through an Identity.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("Const"), 1, {1, 2, 3, 4});
  Output r = ops::Size(s.WithOpName("Size"), c);
  Output i = ops::Identity(s.WithOpName("Identity"), r);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  for (const char* node : {"Size", "Identity"}) {
    const auto& props = properties.GetOutputProperties(node);
    // Fail cleanly instead of indexing an empty vector.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
    EXPECT_TRUE(prop.has_value()) << node;
    ExpectTensorValues({24}, prop.value());
  }
}
1130 
TEST_F(GraphPropertiesTest,PackWithConstMinus1AndReshapes)1131 TEST_F(GraphPropertiesTest, PackWithConstMinus1AndReshapes) {
1132   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1133   Output shape0 = ops::Const(s.WithOpName("shape0"), 4, {});
1134   Output shape1 = ops::Const(s.WithOpName("shape1"), -1, {});
1135   Output pack = ops::Stack(s.WithOpName("pack"), {shape0, shape1});
1136   // pack is [2], with values {4, -1}.
1137 
1138   Output x0_ = ops::Placeholder(s.WithOpName("x0_"), DataType::DT_FLOAT);
1139   Output x1_ = ops::Placeholder(s.WithOpName("x1_"), DataType::DT_FLOAT);
1140 
1141   Output x0 = ops::Reshape(s.WithOpName("x0"), x0_, pack);
1142   Output x1 = ops::Reshape(s.WithOpName("x1"), x1_, pack);
1143   // Two unknown rank tensors (x0_ and x1_) are reshaped with pack {4, -1},
1144   // their output shapes would be [4, -1]. However, though we use the same
1145   // shape input to the Reshape ops, their output shapes can be different;
1146   // i.e., unknown dim values (-1) of x0 and x1 shapes are not necessarily
1147   // the same.
1148 
1149   // if input to the Select ops. Note that s0 has a fully defined shape, while
1150   // s1 has unknown shape.
1151   Output s0 = ops::Const(s.WithOpName("s0"), true, {4, 16});
1152   Output s1 = ops::Placeholder(s.WithOpName("s1"), DataType::DT_BOOL);
1153 
1154   Output y0 = ops::Placeholder(s.WithOpName("y0"), DataType::DT_FLOAT);
1155   Output y1 = ops::Placeholder(s.WithOpName("y1"), DataType::DT_FLOAT);
1156 
1157   // We instantiate SelectV2, but will replace it with Select. The shape
1158   // inference function for Select links all inputs and outputs as they should
1159   // have the same shapes.
1160   Output z0 = ops::SelectV2(s.WithOpName("z0"), s0, x0, y0);
1161   Output z1 = ops::SelectV2(s.WithOpName("z1"), s1, x1, y1);
1162 
1163   // For z0, as we know the shape of s0, symbolic shape manager in shape
1164   // inference will make the shapes of x0, y0, and z0 equal to the shape of s0,
1165   // which is [4, 16].
1166   // For z1, s0 and y1 are all unknown shapes, so we can infer they're [4, -1]
1167   // at best.
1168   // Note that x0 and x1 share the same shape input to the Reshape op, but
1169   // -1 in the shape input should not be treated as the same symoblic unknown
1170   // dim; it is merely a constant value -1 for identitying unknown dim for
1171   // Reshape operation.
1172 
1173   GrapplerItem item;
1174   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1175 
1176   // Replace SelectV2 op with Select op.
1177   for (int i = 0; i < item.graph.node_size(); ++i) {
1178     auto* node = item.graph.mutable_node(i);
1179     if (node->op() == "SelectV2") {
1180       node->set_op("Select");
1181     }
1182   }
1183 
1184   GraphProperties properties(item);
1185   TF_ASSERT_OK(properties.InferStatically(false));
1186   for (const auto& node_name : {"x0", "y0", "z0"}) {
1187     const auto out_props = properties.GetOutputProperties(node_name);
1188     const OpInfo::TensorProperties out_prop0 = out_props[0];
1189     EXPECT_EQ("float: [4,16]", PropToString(out_prop0));
1190   }
1191   {
1192     const auto out_props = properties.GetOutputProperties("s0");
1193     const OpInfo::TensorProperties out_prop0 = out_props[0];
1194     EXPECT_EQ("bool: [4,16]", PropToString(out_prop0));
1195   }
1196 
1197   for (const auto& node_name : {"x1", "y1", "z1"}) {
1198     const auto out_props = properties.GetOutputProperties(node_name);
1199     const OpInfo::TensorProperties out_prop0 = out_props[0];
1200     EXPECT_EQ("float: [4,-1]", PropToString(out_prop0));
1201   }
1202   // if input of Select can be either vector or the same shape to the
1203   // input/output; in this case, even if we know input and output are
1204   // [4, ?], we can't say it's [4, ?] or a vector; hence, it shoudl be
1205   // unknown.
1206   {
1207     const auto out_props = properties.GetOutputProperties("s1");
1208     const OpInfo::TensorProperties out_prop0 = out_props[0];
1209     EXPECT_EQ("bool: ?", PropToString(out_prop0));
1210   }
1211 }
1212 
TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
  // ops fed by Consts.
  // If output_tensors_as_shape is not set for those Identity ops, or the Pack
  // op doesn't take input_tensors_as_shape, the Fill op's input doesn't have a
  // value; hence, its output shape becomes unknown.
  Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
  Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
  Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
  Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
  Output a = ops::Identity(s.WithOpName("a"), a0);
  Output b = ops::Identity(s.WithOpName("b"), b0);
  Output c = ops::Identity(s.WithOpName("c"), c0);
  Output d = ops::Identity(s.WithOpName("d"), d0);
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  // Fail cleanly instead of indexing an empty vector.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
}
1244 
TEST_F(GraphPropertiesTest,FunctionWithDtResourceInput)1245 TEST_F(GraphPropertiesTest, FunctionWithDtResourceInput) {
1246   // Function ops may have DT_RESOURCE input; if not properly set shapes and
1247   // dtypes through the DT_RESOURCE _Arg, we cannot infer output shapes of such
1248   // function ops.
1249   GrapplerItem item;
1250   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1251                                  "function_with_dt_resource_input.pbtxt");
1252   TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1253 
1254   // This graph evaluates FunctionWithDtResourceInput with two inputs:
1255   // x [DT_FLOAT Const],
1256   // _Arg [DT_RESOURCE _Arg]
1257   // and has two outputs:
1258   // z1 = x + _Arg
1259   // z2 = x
1260   {
1261     GraphProperties properties(item);
1262     TF_ASSERT_OK(properties.InferStatically(false));
1263     const auto out_props =
1264         properties.GetOutputProperties("FunctionWithDtResourceInput");
1265     EXPECT_EQ(out_props.size(), 2);
1266     const OpInfo::TensorProperties out_prop0 = out_props[0];
1267     EXPECT_EQ("float: [1,3]", PropToString(out_prop0));
1268     const OpInfo::TensorProperties out_prop1 = out_props[1];
1269     EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1270   }
1271 
1272   {
1273     // Delete _handle_dtypes and _handle_shapes attr for the input _Arg node.
1274     for (int i = 0; i < item.graph.node_size(); i++) {
1275       auto* node = item.graph.mutable_node(i);
1276       if (node->name() == "y") {  // _Arg node with DT_RESOURCE
1277         node->mutable_attr()->erase("_handle_dtypes");
1278         node->mutable_attr()->erase("_handle_shapes");
1279         break;
1280       }
1281     }
1282     // We cannot infer the function output shape correctly without those attr,
1283     // but still it shouldn't fail; also, there can be some shapes we can
1284     // infer in such a case. In this test graph,
1285     // z2 of the function node just returns x input; hence, even if _Arg's shape
1286     // cannot be inferred, we can infer z2 output shape.
1287     GraphProperties properties(item);
1288     TF_ASSERT_OK(properties.InferStatically(false));
1289     const auto out_props =
1290         properties.GetOutputProperties("FunctionWithDtResourceInput");
1291     EXPECT_EQ(out_props.size(), 2);
1292     const OpInfo::TensorProperties out_prop0 = out_props[0];
1293     // Without shape and dtype attr, we don't know _Arg's shape; hence, unknown
1294     // for x + _Arg.
1295     EXPECT_EQ("float: ?", PropToString(out_prop0));
1296     // The 2nd output is just x, so even if _Arg's shape is unknown, we can
1297     // infer this output shape.
1298     const OpInfo::TensorProperties out_prop1 = out_props[1];
1299     EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1300   }
1301 }
1302 
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
  // Calls MyFillFunc (assumed to come from the fixture's function_lib_, which
  // is the only library added here) with a Const shape {1, 2, 3, 4} and a
  // scalar fill value. The test expects the function output shape to be
  // inferred from the Const shape value.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
  Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                         s.graph()->op_registry());
  tensorflow::Node* func_op;
  auto shape_in = tensorflow::ops::AsNodeOut(s, shape);
  auto value_in = tensorflow::ops::AsNodeOut(s, value);
  TF_ASSERT_OK(
      builder.Input(shape_in).Input(value_in).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyFillFunc");
  // Fail cleanly instead of indexing an empty vector.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
}
1324 
TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
  // Same as FunctionWithConstInput, but the function's shape input is an
  // Identity of the Const rather than the Const itself. The function output
  // shape can only be inferred if the shape value is propagated through the
  // Identity, rather than read directly from a Const input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
  Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
  Output shape = ops::Identity(s.WithOpName("shape"), shape_);
  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                         s.graph()->op_registry());
  tensorflow::Node* func_op;
  auto shape_in = tensorflow::ops::AsNodeOut(s, shape);
  auto value_in = tensorflow::ops::AsNodeOut(s, value);
  TF_ASSERT_OK(
      builder.Input(shape_in).Input(value_in).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyFillFunc");
  // Fail cleanly instead of indexing an empty vector.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
}
1350 
TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32"},                                               // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));

  // MyFunc takes Const (shape) and passes it with Identity. Expect function
  // output has the same shape as well as value (output_tensors_as_shape) as
  // input Const tensor.
  Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
  auto shape_in = tensorflow::ops::AsNodeOut(s, shape);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op;
  TF_ASSERT_OK(builder.Input(shape_in).Finalize(s.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  const auto& in_props = properties.GetInputProperties("MyFunc");
  // Fail cleanly instead of indexing an empty vector.
  ASSERT_FALSE(out_props.empty());
  ASSERT_FALSE(in_props.empty());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("int32: [2]", PropToString(out_prop0));
  EXPECT_TRUE(out_prop0.has_value());
  ExpectTensorValues({5, 7}, out_prop0.value());
  ExpectTensorValues({5, 7}, in_props[0].value());
}
1386 
TEST_F(GraphPropertiesTest,ArithmeticFunctionReturnTensorValue)1387 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
1388   FunctionDefLibrary library;
1389   // Function that adds two input values.
1390   *library.add_function() = FunctionDefHelper::Create(
1391       "MyFunc",                                                   // Name
1392       {"x: int32", "y: int32"},                                   // Inputs
1393       {"out: int32"},                                             // Outputs
1394       {},                                                         // Attrs
1395       {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}},  // Nodes
1396       {{"out", "a:z:0"}});                                        // Returns
1397   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1398   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1399 
1400   Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
1401   auto _shape = tensorflow::ops::AsNodeOut(s, shape);
1402   auto builder =
1403       tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1404   tensorflow::Node* func_op;
1405   TF_ASSERT_OK(
1406       builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));
1407 
1408   GrapplerItem item;
1409   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1410   {
1411     GraphProperties properties(item);
1412     // Without aggressive_shape_inference, the internal function does not
1413     // evaluate output value.
1414     TF_ASSERT_OK(properties.InferStatically(
1415         /*assume_valid_feeds=*/true,
1416         /*aggressive_shape_inference=*/false,
1417         /*include_tensor_values=*/true));
1418     const auto out_props = properties.GetOutputProperties("MyFunc");
1419     const OpInfo::TensorProperties out_prop0 = out_props[0];
1420     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1421     EXPECT_FALSE(out_prop0.has_value());
1422   }
1423 
1424   {
1425     GraphProperties properties(item);
1426     // With aggressive_shape_inference, output value is evaluated.
1427     TF_ASSERT_OK(properties.InferStatically(
1428         /*assume_valid_feeds=*/true,
1429         /*aggressive_shape_inference=*/true,
1430         /*include_tensor_values=*/true));
1431     const auto out_props = properties.GetOutputProperties("MyFunc");
1432     const OpInfo::TensorProperties out_prop0 = out_props[0];
1433     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1434     EXPECT_TRUE(out_prop0.has_value());
1435 
1436     ExpectTensorValues({10, 14}, out_prop0.value());
1437     ExpectTensorValues({5, 7},
1438                        properties.GetInputProperties("MyFunc")[0].value());
1439     ExpectTensorValues({5, 7},
1440                        properties.GetInputProperties("MyFunc")[1].value());
1441   }
1442 }
1443 
1444 // Same as the above, but float values; also, one of the function input is
1445 // Identity of Const.
TEST_F(GraphPropertiesTest,ArithmeticFunctionReturnTensorValueFloat)1446 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValueFloat) {
1447   FunctionDefLibrary library;
1448   // Function that adds two input values.
1449   *library.add_function() = FunctionDefHelper::Create(
1450       "MyFunc",                                                   // Name
1451       {"x: float", "y: float"},                                   // Inputs
1452       {"out: float"},                                             // Outputs
1453       {},                                                         // Attrs
1454       {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
1455       {{"out", "a:z:0"}});                                        // Returns
1456   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1457   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1458 
1459   Output x1 = ops::Const(s.WithOpName("x1"), {5.0f, 7.0f}, {2});
1460   Output x2 = ops::Identity(s.WithOpName("x1"), x1);
1461   auto _x1 = tensorflow::ops::AsNodeOut(s, x1);
1462   auto _x2 = tensorflow::ops::AsNodeOut(s, x2);
1463   auto builder =
1464       tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1465   tensorflow::Node* func_op;
1466   TF_ASSERT_OK(builder.Input(_x1).Input(_x2).Finalize(s.graph(), &func_op));
1467 
1468   GrapplerItem item;
1469   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1470   {
1471     GraphProperties properties(item);
1472     // Without aggressive_shape_inference, the internal function does not
1473     // evaluate output value.
1474     TF_ASSERT_OK(properties.InferStatically(
1475         /*assume_valid_feeds=*/true,
1476         /*aggressive_shape_inference=*/false,
1477         /*include_tensor_values=*/true));
1478     const auto out_props = properties.GetOutputProperties("MyFunc");
1479     const OpInfo::TensorProperties out_prop0 = out_props[0];
1480     EXPECT_EQ("float: [2]", PropToString(out_prop0));
1481     EXPECT_FALSE(out_prop0.has_value());
1482   }
1483 
1484   {
1485     GraphProperties properties(item);
1486     // With aggressive_shape_inference, output value is evaluated.
1487     TF_ASSERT_OK(properties.InferStatically(
1488         /*assume_valid_feeds=*/true,
1489         /*aggressive_shape_inference=*/true,
1490         /*include_tensor_values=*/true));
1491     const auto out_props = properties.GetOutputProperties("MyFunc");
1492     const OpInfo::TensorProperties out_prop0 = out_props[0];
1493     EXPECT_EQ("float: [2]", PropToString(out_prop0));
1494     EXPECT_TRUE(out_prop0.has_value());
1495 
1496     ExpectFloatTensorValues({10.0, 14.0}, out_prop0.value());
1497     ExpectFloatTensorValues({5.0, 7.0},
1498                             properties.GetInputProperties("MyFunc")[0].value());
1499     ExpectFloatTensorValues({5.0, 7.0},
1500                             properties.GetInputProperties("MyFunc")[1].value());
1501   }
1502 }
1503 
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
  // Create graph with a function that takes a scalar value so that we use
  // Placeholder with scalar as for input to the function shape inference.
  // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
  // the input; all tensors are scalars.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float"},                                               // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
  Output placeholder =
      ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({})));
  Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
  auto _identity = tensorflow::ops::AsNodeOut(s, identity);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  TF_ASSERT_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Graphs produced at version <= 21 (i.e. < 22) infer the output shape of a
  // Placeholder with an empty shape attr as unknown, instead of scalar.
  EXPECT_GT(item.graph.versions().producer(), 21);

  // MyFunc output shouldn't be unknown rank.
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  // ASSERT before indexing so a failed inference fails cleanly.
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());
}
1543 
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x,y)

    with tf.Graph().as_default():
      x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = MyAdd(x, y)
      z = MyAdd(x, z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "simple_function.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Sizes are ASSERTed so the indexing below never goes out of range.
  const auto& out_props = properties.GetOutputProperties("MyAdd_55e046a8");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_FALSE(out_prop.shape().unknown_rank());
  ASSERT_EQ(2, out_prop.shape().dim_size());
  EXPECT_EQ(1, out_prop.shape().dim(0).size());
  EXPECT_EQ(2, out_prop.shape().dim(1).size());

  const auto& in_props = properties.GetInputProperties("MyAdd_55e046a8");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1580 
// Static shape inference through a large function call node ("y0") loaded
// from large_function_graph.pbtxt: checks all inferred input/output shapes.
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_function_graph.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT (not EXPECT) the counts: the elements are indexed right after.
  const auto& out_props = properties.GetOutputProperties("y0");
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));

  const auto& in_props = properties.GetInputProperties("y0");
  ASSERT_EQ(4, in_props.size());

  const OpInfo::TensorProperties& in_prop0 = in_props[0];
  EXPECT_EQ("float: [64]", PropToString(in_prop0));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));

  const OpInfo::TensorProperties& in_prop2 = in_props[2];
  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));

  const OpInfo::TensorProperties& in_prop3 = in_props[3];
  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
1613 
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
  // Test graph produced in python using:
  /*
    @function.Defun(noinline=True)
    def MyFunc():
      @function.Defun(*[tf.float32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[tf.float32] * 2)
      def Body(n, x):
        return n - 1, x + n

      i = tf.constant(10)
      return functional_ops.While([i, 0.], Cond, Body)

    with tf.Graph().as_default():
      z = MyFunc()
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_functional_while.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the count before indexing both outputs.
  const auto& out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_INT32, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
  EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
1651 
// Loads function_error.pbtxt. The call node's output is expected to have
// float dtype but unknown rank, while both inputs are float [1,2].
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_error.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the counts: the elements are indexed right after.
  const auto& out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(1, out_props.size());

  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_TRUE(out_prop.shape().unknown_rank());

  const auto& in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1676 
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the counts before indexing.
  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1710 
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_2.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the counts before indexing.
  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1743 
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      c = tf.add(x, a)
      d = tf.add(y, b)
      return c

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_shapes.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the counts before indexing.
  const auto& out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto& in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
1780 
TEST_F(GraphPropertiesTest,SymbolicShapes)1781 TEST_F(GraphPropertiesTest, SymbolicShapes) {
1782   // Build a simple graph with placeholders of unknown dimensions. These
1783   // dimensions will be encoded symbolically.
1784   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1785 
1786   Output a =
1787       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
1788                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
1789   Output b =
1790       ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
1791                        ops::Placeholder::Shape(PartialTensorShape({-1})));
1792   Output c = ops::Identity(s.WithOpName("c"), a);
1793   Output d = ops::Identity(s.WithOpName("d"), b);
1794   Output e = ops::Add(s.WithOpName("e"), c, d);
1795   Output f = ops::Add(s.WithOpName("f"), a, c);
1796 
1797   Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
1798   Output g = ops::Shape(s.WithOpName("g"), c);
1799   Output h = ops::Fill(s.WithOpName("h"), g, zero);
1800   Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
1801   Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
1802 
1803   GrapplerItem item;
1804   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1805 
1806   GraphProperties properties(item);
1807   TF_ASSERT_OK(properties.InferStatically(false));
1808   const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
1809   const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
1810   EXPECT_EQ(2, shape_a.dim_size());
1811   EXPECT_EQ(shape_a.dim_size(), shape_c.dim_size());
1812   EXPECT_GE(-2, shape_a.dim(0).size());
1813   EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
1814   EXPECT_GE(-2, shape_a.dim(1).size());
1815   EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());
1816 
1817   PartialTensorShape shape(shape_a);
1818   EXPECT_FALSE(shape.IsFullyDefined());
1819   EXPECT_FALSE(shape.unknown_rank());
1820 
1821   const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
1822   const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
1823   EXPECT_EQ(1, shape_b.dim_size());
1824   EXPECT_EQ(shape_b.dim_size(), shape_d.dim_size());
1825   EXPECT_GE(-2, shape_b.dim(0).size());
1826   EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
1827   EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());
1828 
1829   const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
1830   ASSERT_EQ(2, shape_e.dim_size());
1831   EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
1832   EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
1833   EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());
1834 
1835   const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
1836   ASSERT_EQ(2, shape_f.dim_size());
1837   EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
1838   EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
1839 
1840   const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
1841   ASSERT_EQ(2, shape_f.dim_size());
1842   EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
1843   EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
1844 
1845   const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
1846   ASSERT_EQ(1, shape_j.dim_size());
1847   EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
1848 }
1849 
// Static inference must not validate colocation constraints: node "c" asks to
// be colocated with "a", and "a" is then removed from the graph.
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output anchor = ops::Const(root.WithOpName("a"), 1.0f, {1});
  Output other = ops::Const(root.WithOpName("b"), 2.0f, {1});
  Output colocated =
      ops::Const(root.WithOpName("c").ColocateWith(anchor), 3.0f, {1});

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  // Simulate a late optimization pass that dropped "a". Colocation has
  // already been validated and placement is done by then, so the dangling
  // reference from "c" is acceptable.
  GraphDef pruned;
  for (const auto& node_def : item.graph.node()) {
    if (node_def.name() == "a") continue;
    *pruned.add_node() = node_def;
  }
  item.graph.Swap(&pruned);

  // Inference is expected to succeed despite the missing colocation target.
  GraphProperties properties(item);
  TF_EXPECT_OK(properties.InferStatically(false));
}
1873 
// ShapeN fans the (partially unknown) shapes of "a" and "b" out to two Fill
// ops; each filled tensor must carry exactly the shape of its source.
TEST_F(GraphPropertiesTest, ShapeTracking) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output matrix_in = ops::Placeholder(
      root.WithOpName("a"), DT_FLOAT,
      ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output vector_in = ops::Placeholder(
      root.WithOpName("b"), DT_FLOAT,
      ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output fill_value = ops::Const(root.WithOpName("zero"), 0.0f, {});
  auto shapes = ops::ShapeN(root.WithOpName("shapes"), {matrix_in, vector_in});
  Output filled_a = ops::Fill(root.WithOpName("o1"), shapes[0], fill_value);
  Output filled_b = ops::Fill(root.WithOpName("o2"), shapes[1], fill_value);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Debug rendering of the shape of output 0 of the named node.
  auto shape_str = [&properties](const string& node_name) {
    return properties.GetOutputProperties(node_name).at(0).shape().DebugString();
  };
  EXPECT_EQ(shape_str("a"), shape_str("o1"));
  EXPECT_EQ(shape_str("b"), shape_str("o2"));
}
1899 
TEST_F(GraphPropertiesTest,FedNodes)1900 TEST_F(GraphPropertiesTest, FedNodes) {
1901   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
1902                                           cluster_->GetDeviceNames());
1903   GrapplerItem item;
1904   CHECK(fake_input.NextItem(&item));
1905 
1906   {
1907     // Conservative shape analysis: the shape of fed ports should be unknown
1908     GraphProperties properties(item);
1909     Status s = properties.InferStatically(false);
1910     TF_ASSERT_OK(s);
1911     for (const auto& node : item.graph.node()) {
1912       if (node.op() == "Const") {
1913         continue;
1914       }
1915       const auto in_props = properties.GetInputProperties(node.name());
1916       EXPECT_EQ(1, in_props.size());
1917       const OpInfo::TensorProperties& in_prop = in_props[0];
1918       const auto out_props = properties.GetOutputProperties(node.name());
1919       EXPECT_EQ(1, out_props.size());
1920       const OpInfo::TensorProperties& out_prop = out_props[0];
1921 
1922       if (node.name() == "x") {
1923         // x is fed: its input should have a known shape, while its output
1924         // doesn't
1925         EXPECT_FALSE(in_prop.shape().unknown_rank());
1926         EXPECT_EQ(1, in_prop.shape().dim_size());
1927         EXPECT_EQ(2, in_prop.shape().dim(0).size());
1928         EXPECT_TRUE(out_prop.shape().unknown_rank());
1929       } else if (node.op() == "Square" || node.op() == "AddN") {
1930         // These nodes are in the fanout of x: their shapes should be unknown.
1931         EXPECT_TRUE(in_prop.shape().unknown_rank());
1932         EXPECT_TRUE(out_prop.shape().unknown_rank());
1933       }
1934     }
1935   }
1936   {
1937     // Optimistic shape analysis: the shape of fed ports should be derived from
1938     // the shape of the fanin.
1939     GraphProperties properties(item);
1940     Status s = properties.InferStatically(true);
1941     TF_ASSERT_OK(s);
1942     for (const auto& node : item.graph.node()) {
1943       if (node.op() == "Square" || node.op() == "AddN") {
1944         const auto in_props = properties.GetInputProperties(node.name());
1945         EXPECT_EQ(1, in_props.size());
1946         const OpInfo::TensorProperties& in_prop = in_props[0];
1947         EXPECT_EQ(DT_FLOAT, in_prop.dtype());
1948         EXPECT_FALSE(in_prop.shape().unknown_rank());
1949         EXPECT_EQ(2, in_prop.shape().dim_size());
1950         const auto out_props = properties.GetOutputProperties(node.name());
1951         EXPECT_EQ(1, out_props.size());
1952         const OpInfo::TensorProperties& out_prop = out_props[0];
1953         EXPECT_EQ(in_prop.dtype(), out_prop.dtype());
1954         EXPECT_EQ(in_prop.shape().DebugString(),
1955                   out_prop.shape().DebugString());
1956       }
1957     }
1958   }
1959 }
1960 
// Loads a large graph with many nested loops and checks that static shape
// inference over it succeeds (the intent is to catch slow inference).
TEST_F(GraphPropertiesTest, Performance) {
  GrapplerItem item;
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "large_graph.pbtxt.html");
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));

  // Populate default attrs, resolving ops through the global registry plus
  // the graph's own function library.
  const FunctionLibraryDefinition flib(OpRegistry::Global(),
                                       item.graph.library());
  TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&item.graph, flib,
                                         /*node_offset=*/0,
                                         /*skip_unknown_ops=*/true));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
}
1976 
// Slicing the shape vector of a [-1, -1] placeholder ([0:1] and [1:2]) and
// filling with the slices must reproduce the placeholder's symbolic dims.
TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  auto shp = ops::Shape(s.WithOpName("shape"), {a});

  Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
  Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
  Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});

  // b = shape[0:1], c = shape[1:2] (index2 doubles as the unit stride).
  Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
  Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
  const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
  // Ranks are ASSERTed because the dims are indexed right after.
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(1, shape_o1.dim_size());
  ASSERT_EQ(1, shape_o2.dim_size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
}
2009 
// shape(placeholder)[0:3] with shrink_axis_mask=1 yields the scalar 5 (the
// placeholder's first dimension). Only aggressive inference computes it.
TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(scope.WithOpName("input_placeholder"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
  auto input_shape = ops::Shape(scope.WithOpName("input_shape"), placeholder);

  Output begin = ops::Const(scope.WithOpName("begin"), {0}, {1});
  Output end = ops::Const(scope.WithOpName("end"), {3}, {1});
  Output stride = ops::Const(scope.WithOpName("stride"), {1}, {1});

  Output slice =
      ops::StridedSlice(scope.WithOpName("slice"), input_shape, begin, end,
                        stride, ops::StridedSlice::ShrinkAxisMask(1));

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  // Without aggressive shape inference, it cannot infer output value of
  // StridedSlice with ShrinkAxisMask.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value());
  }

  // InferStatically with aggressive shape inference can infer output value of
  // StridedSlice with ShrinkAxisMask.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& slice_prop = properties.GetOutputProperties("slice").at(0);
    // ASSERT: without a value, value() would just be an empty proto.
    ASSERT_TRUE(slice_prop.has_value());
    ExpectTensorValues({5}, slice_prop.value());
  }
}
2053 
TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
  // Builds a small chain of OnesLike/Add ops over int32 constants and checks
  // that aggressive inference with include_tensor_values propagates the
  // constant values through every node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output b = ops::Const(s.WithOpName("b"), {8, 8}, {2});
  Output c = ops::Const(s.WithOpName("c"), {2, 2}, {2});

  Output a1 = ops::OnesLike(s.WithOpName("a1"), a);
  Output a_plus_one = ops::Add(s.WithOpName("a_plus_one"), a, a1);
  Output a_plus_a = ops::Add(s.WithOpName("a_plus_a"), a, a);
  Output b_plus_2a = ops::Add(s.WithOpName("b_plus_2a"), b, a_plus_a);
  Output c_plus_b_plus_2a =
      ops::Add(s.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));

  // Check output shapes and values. Each output vector is size-checked with
  // ASSERT_EQ before indexing, so a missing node fails the test instead of
  // reading out of bounds.
  const auto& a_plus_one_props = properties.GetOutputProperties("a_plus_one");
  ASSERT_EQ(1, a_plus_one_props.size());
  EXPECT_EQ("int32: [2]", PropToString(a_plus_one_props[0]));
  EXPECT_TRUE(a_plus_one_props[0].has_value());
  ExpectTensorValues({6, 8}, a_plus_one_props[0].value());  // a + 1

  const auto& a_plus_a_props = properties.GetOutputProperties("a_plus_a");
  ASSERT_EQ(1, a_plus_a_props.size());
  EXPECT_EQ("int32: [2]", PropToString(a_plus_a_props[0]));
  EXPECT_TRUE(a_plus_a_props[0].has_value());
  ExpectTensorValues({10, 14}, a_plus_a_props[0].value());  // 2a

  const auto& b_plus_2a_props = properties.GetOutputProperties("b_plus_2a");
  ASSERT_EQ(1, b_plus_2a_props.size());
  EXPECT_EQ("int32: [2]", PropToString(b_plus_2a_props[0]));
  EXPECT_TRUE(b_plus_2a_props[0].has_value());
  ExpectTensorValues({18, 22}, b_plus_2a_props[0].value());  // b + 2a

  const auto& c_plus_b_plus_2a_props =
      properties.GetOutputProperties("c_plus_b_plus_2a");
  ASSERT_EQ(1, c_plus_b_plus_2a_props.size());
  EXPECT_EQ("int32: [2]", PropToString(c_plus_b_plus_2a_props[0]));
  EXPECT_TRUE(c_plus_b_plus_2a_props[0].has_value());
  ExpectTensorValues({20, 24}, c_plus_b_plus_2a_props[0].value());  // c+b+2a
}
2097 
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  // An Identity node annotated with output shape [5, 7] consumes a
  // Placeholder of shape [-1, -1]. The annotation should be used only with
  // aggressive_shape_inference.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  // NOTE(review): Identity's type attr is "T" (derived from the DT_FLOAT input
  // below); the "dtype" attr here looks vestigial. Confirm before removing.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({5, 7})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    // ASSERT (not EXPECT) so props[0] below is never read out of bounds.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
2142 
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  // The inferred shape [-1, 100] is compatible with the annotated [10, 100],
  // so aggressive inference should adopt the annotation.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 100})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // ASSERT (not EXPECT) so props[0] below is never read out of bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2170 
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  // The inferred shape [-1, 100] conflicts with the annotated [10, 10] in
  // dim 1, so the annotation must be ignored even with aggressive inference.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 10})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // ASSERT (not EXPECT) so props[0] below is never read out of bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
2198 
TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
  // An op with no shape function (TestOpWithNoInferenceFn, registered
  // elsewhere in this file) should take its output shape from the annotation
  // under aggressive inference.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(
      NodeDefBuilder("TestOpWithNoInferenceFn", "TestOpWithNoInferenceFn")
          .Attr("_same_output_for_iterations", true)
          .Attr("_output_shape_vector", {TensorShape({10, 100})})
          .Input("Input", 0, DT_FLOAT)
          .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
  // ASSERT (not EXPECT) so props[0] below is never read out of bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2225 
TEST_F(GraphPropertiesTest, PartitionedCallOp) {
  // A PartitionedCall to an int32 identity function over a constant of
  // length 4 should report an int32 [4] output.
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;
  FunctionDef called_func = FunctionDefHelper::Create(
      "identity_function",
      /*in_def=*/{"arg0: int32"},
      /*out_def=*/{"ret0: int32"},
      /*attr_def=*/{},
      {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "Identity:output:0"}});
  *library.add_function() = called_func;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));

  Output in = ops::Const(root, {3, 1, 2, 0});
  NameAttrList b_name_attr;
  b_name_attr.set_name("identity_function");
  ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  // The call was built with a single output type, so exactly one output is
  // expected. Check before indexing to avoid an out-of-bounds read.
  const auto& call_props = properties.GetOutputProperties("identity_call");
  ASSERT_EQ(1, call_props.size());
  EXPECT_EQ("int32: [4]", PropToString(call_props[0]));
}
2257 
TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
  // A PartitionedCall to a function that adds two int32 placeholders of
  // partially-known shape [2, 2, -1] should keep that partial shape.
  auto f = FunctionDefHelper::Create(
      // Name
      "FunctionWhichAdds",
      // Inputs
      {"arg0: int32", "arg1: int32"},
      // Outputs
      {"ret0: int32"},
      /*attr_def=*/{},
      // Nodes
      {{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "a:z:0"}});

  FunctionDefLibrary function_lib;
  function_lib.add_function()->Swap(&f);
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib));

  PartialTensorShape input_shape({2, 2, -1});
  Output in1 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  Output in2 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  NameAttrList b_name_attr;
  b_name_attr.set_name("FunctionWhichAdds");
  ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  // The call was built with a single output type; check before indexing to
  // avoid an out-of-bounds read.
  const auto& call_props = properties.GetOutputProperties("add_call");
  ASSERT_EQ(1, call_props.size());
  EXPECT_EQ("int32: [2,2,-1]", PropToString(call_props[0]));
}
2298 
TEST_F(GraphPropertiesTest,ShapeAnnotatedFunctionOp)2299 TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
2300   // A function, which we cannot infer output shape statically.
2301   auto f = FunctionDefHelper::Create(
2302       // Name
2303       "FuncShapeCannotBeInferred",
2304       // Inputs
2305       {},
2306       // Outputs
2307       {"output: float"},
2308       // Attrs
2309       {},
2310       // Nodes
2311       {
2312           // Placeholder without shape attr; unknown rank.
2313           {{"p"}, "Placeholder", {}, {{"dtype", DataType::DT_FLOAT}}},
2314       },
2315       // Returns
2316       {{"output", "p:output:0"}});
2317   FunctionDefLibrary function_lib;
2318   function_lib.add_function()->Swap(&f);
2319   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2320   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib));
2321   tensorflow::Node* func_op;
2322   TensorShapeProto output_shape;
2323   output_shape.set_unknown_rank(false);
2324   output_shape.add_dim()->set_size(1);
2325   output_shape.add_dim()->set_size(2);
2326   output_shape.add_dim()->set_size(3);
2327   output_shape.add_dim()->set_size(4);
2328   // The function node, f, includes shape annotation.
2329   TF_ASSERT_OK(tensorflow::NodeBuilder("f", "FuncShapeCannotBeInferred",
2330                                        s.graph()->op_registry())
2331                    .Attr("_execution_count", 1)
2332                    .Attr("_same_output_for_iterations", true)
2333                    .Attr("_output_dtype_vector", {DataType::DT_FLOAT})
2334                    .Attr("_output_shape_vector", {output_shape})
2335                    .Finalize(s.graph(), &func_op));
2336   GrapplerItem item;
2337   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2338 
2339   // InferStatically with aggressive_shape_inference would fail to infer
2340   // the output shape of the node f.
2341   {
2342     GraphProperties properties(item);
2343     TF_ASSERT_OK(properties.InferStatically(
2344         /*assume_valid_feeds=*/false,
2345         /*aggressive_shape_inference=*/false,
2346         /*include_tensor_values=*/false));
2347     const auto out_props = properties.GetOutputProperties("f");
2348     const OpInfo::TensorProperties out_prop0 = out_props[0];
2349     EXPECT_EQ("float: ?", PropToString(out_prop0));
2350   }
2351   // With aggressive_shape_inference, it skips recursively callying
2352   // InferStatically for the function node and outputs annotated shape info.
2353   {
2354     GraphProperties properties(item);
2355     TF_ASSERT_OK(properties.InferStatically(
2356         /*assume_valid_feeds=*/false,
2357         /*aggressive_shape_inference=*/true,
2358         /*include_tensor_values=*/true));
2359     const auto out_props = properties.GetOutputProperties("f");
2360     const OpInfo::TensorProperties out_prop0 = out_props[0];
2361     EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
2362   }
2363 }
2364 
TEST_F(GraphPropertiesTest,SymbolicShapeInferenceWithReshapeOpsSharingShapeVector)2365 TEST_F(GraphPropertiesTest,
2366        SymbolicShapeInferenceWithReshapeOpsSharingShapeVector) {
2367   GrapplerItem item;
2368   // This graph creates a shape vector [-1, 10] from Concat(Const, Const)
2369   // used for two reshape ops. One reshape op is segment_ids input to
2370   // UnsortedSegmentSum op, which applies MergePrefix from its shape function.
2371   // segment_ids has a shape [-1, 10] (from reshape), but MergePrefix with
2372   // data input ([10, 10, 10, 10]) makes -1, or unknown dim, 10, with
2373   // SymbolicShapeRefiner.
2374   // This dim value (10), however, should not affect the other reshape op, even
2375   // though it shares the shape input; -1 in the shape input of Reshape op is
2376   // a special case of computed output dim, not unknown dim.
2377   // data and num_segments are inputs to UnsortedSegmenetSum.
2378 
2379   TF_ASSERT_OK(NodeDefBuilder("data", "Placeholder")
2380                    .Attr("dtype", DT_FLOAT)
2381                    .Attr("shape", TensorShape({10, 10, 10, 10}))
2382                    .Finalize(item.graph.add_node()));
2383   Tensor num_segments(DT_INT32, TensorShape({}));
2384   // Build semgent_ids input to UnsortedSegmentSum from Const ops, ConcatV2,
2385   // and Reshape ops. tensors_as_shape from Const ops are propagated to ConcatV2
2386   // output to form shape vector [-1, 10] to Reshape.
2387   test::FillIota<int>(&num_segments, 3);
2388   TF_ASSERT_OK(NodeDefBuilder("num_segments", "Const")
2389                    .Attr("dtype", DT_INT32)
2390                    .Attr("value", num_segments)
2391                    .Finalize(item.graph.add_node()));
2392   Tensor minus_one(DT_INT32, TensorShape({1}));
2393   test::FillIota<int>(&minus_one, -1);
2394   TF_ASSERT_OK(NodeDefBuilder("minus_one", "Const")
2395                    .Attr("dtype", DT_INT32)
2396                    .Attr("value", minus_one)
2397                    .Finalize(item.graph.add_node()));
2398   Tensor plus_ten(DT_INT32, TensorShape({1}));
2399   test::FillIota<int>(&plus_ten, 10);
2400   TF_ASSERT_OK(NodeDefBuilder("plus_ten", "Const")
2401                    .Attr("dtype", DT_INT32)
2402                    .Attr("value", plus_ten)
2403                    .Finalize(item.graph.add_node()));
2404   Tensor axis(DT_INT32, TensorShape({}));
2405   test::FillIota<int>(&axis, -1);
2406   TF_ASSERT_OK(NodeDefBuilder("axis", "Const")
2407                    .Attr("dtype", DT_INT32)
2408                    .Attr("value", axis)
2409                    .Finalize(item.graph.add_node()));
2410   std::vector<NodeDefBuilder::NodeOut> inputs(2);
2411   inputs[0] = NodeDefBuilder::NodeOut{"minus_one", 0, DT_INT32};
2412   inputs[1] = NodeDefBuilder::NodeOut{"plus_ten", 0, DT_INT32};
2413   TF_ASSERT_OK(NodeDefBuilder("concat", "ConcatV2")
2414                    .Input(inputs)
2415                    .Input("axis", 0, DT_INT32)
2416                    .Attr("N", 2)
2417                    .Attr("T", DT_INT32)
2418                    .Attr("Tidx", DT_INT32)
2419                    .Finalize(item.graph.add_node()));
2420   TF_ASSERT_OK(NodeDefBuilder("segment_ids_", "Placeholder")
2421                    .Attr("dtype", DT_FLOAT)
2422                    .Finalize(item.graph.add_node()));
2423   TF_ASSERT_OK(NodeDefBuilder("segment_ids_shape_before_reshape", "Shape")
2424                    .Input("segment_ids_", 0, DT_FLOAT)
2425                    .Attr("T", DT_FLOAT)
2426                    .Attr("out_type", DT_INT32)
2427                    .Finalize(item.graph.add_node()));
2428   TF_ASSERT_OK(NodeDefBuilder("segment_ids", "Reshape")
2429                    .Input("segment_ids_", 0, DT_FLOAT)
2430                    .Input("concat", 0, DT_INT32)
2431                    .Attr("T", DT_FLOAT)
2432                    .Attr("Tshape", DT_INT32)
2433                    .Finalize(item.graph.add_node()));
2434   // Shape function of UnsortedSegmentSum applies MergePrefix to data and
2435   // segment_ids (the latter being prefix). data shape is [10,10,10,10] and
2436   // segment_ids shape is [-1, 10], but MergePrefix and symbolic shape inference
2437   // assign 10 from data shape to the unknown dim in segment_ids.
2438   TF_ASSERT_OK(NodeDefBuilder("y", "UnsortedSegmentSum")
2439                    .Input("data", 0, DT_FLOAT)
2440                    .Input("segment_ids", 0, DT_INT32)
2441                    .Input("num_segments", 0, DT_INT32)
2442                    .Attr("T", DT_FLOAT)
2443                    .Attr("Tindices", DT_INT32)
2444                    .Attr("Tnumsegments", DT_INT32)
2445                    .Finalize(item.graph.add_node()));
2446   // Note that y2=Reshape(x1) using the same shape vector as segment_ids, but
2447   // y2 shape shouldn't be affected by symbolic shape inference w/ segment_ids.
2448   TF_ASSERT_OK(NodeDefBuilder("x1", "Placeholder")
2449                    .Attr("dtype", DT_FLOAT)
2450                    .Finalize(item.graph.add_node()));
2451   TF_ASSERT_OK(NodeDefBuilder("y1", "Reshape")
2452                    .Input("x1", 0, DT_FLOAT)
2453                    .Input("concat", 0, DT_INT32)
2454                    .Attr("T", DT_FLOAT)
2455                    .Attr("Tshape", DT_INT32)
2456                    .Finalize(item.graph.add_node()));
2457 
2458   GraphProperties properties(item);
2459   TF_ASSERT_OK(properties.InferStatically(true));
2460   const auto& y1_output_properties = properties.GetOutputProperties("y1");
2461   // y1=reshape(x1), but x1's shape in unknown, so y1 should be [-1, 10].
2462   // The first dimension should not be 10.
2463   EXPECT_EQ(y1_output_properties.size(), 1);
2464   EXPECT_EQ(y1_output_properties[0].shape().dim_size(), 2);
2465   EXPECT_LT(y1_output_properties[0].shape().dim(0).size(), 0);
2466   EXPECT_EQ(y1_output_properties[0].shape().dim(1).size(), 10);
2467 }
2468 
TEST(HelperFunctions, IsShapeFullyDefinedIntegerVectorOrScalar) {
  // Make a dummy InferenceContext.
  NodeDef node_def;
  OpRegistrationData op_reg_data;
  OpDefBuilder b("dummy");
  // TF_ASSERT_OK rather than CHECK: a failure is reported as a test failure
  // instead of aborting the whole test binary.
  TF_ASSERT_OK(b.Finalize(&op_reg_data));
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types;
  InferenceContext ic(/*graph_def_version=*/0, node_def, op_reg_data.op_def,
                      /*input_shapes=*/{},
                      /*input_tensors=*/{},
                      /*input_tensors_as_shapes=*/{},
                      std::move(input_handle_shapes_and_types));

  // ShapeHandles for testing.
  ShapeHandle fully_defined_vector = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle vector_with_unknown = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.UnknownDim(), ic.MakeDim(7)});
  // INT64_MAX is used as unknown from Const. See the kUnknownFromConst
  // constant in graph_properties.cc.
  ShapeHandle vector_with_unknown_from_const = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(INT64_MAX), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle rank_1_vector = ic.MakeShape({ic.MakeDim(4)});

  // Rank-1 shape and fully defined tensor_as_shape with INT32 or INT64.
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT32));
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT64));

  // Non-integer data type.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_FLOAT));

  // tensor_as_shape including Unknown or UnknownFromConst.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown_from_const, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, ic.UnknownShape(), DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, ic.UnknownShape(), fully_defined_vector, DT_INT32));

  // shape rank > 1.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, fully_defined_vector, vector_with_unknown_from_const, DT_INT32));
}
2518 }  // namespace
2519 }  // namespace grappler
2520 }  // namespace tensorflow
2521