• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/functional_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/graph_def_util.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
31 #include "tensorflow/core/grappler/inputs/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/test.h"
37 #ifdef INTEL_MKL
38 #include "tensorflow/core/graph/mkl_graph_util.h"
39 #endif
40 
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44 
45 using shape_inference::InferenceContext;
46 using shape_inference::ShapeAndType;
47 using shape_inference::ShapeHandle;
48 
// Directory, relative to the TensorFlow source root, holding the serialized
// GraphDef fixtures (*.pbtxt) that some tests below load from disk.
constexpr char kTestDataPath[] =
    "core/grappler/costs/graph_properties_testdata";
50 
51 REGISTER_OP("TestOpWithNoInferenceFn")
52     .Input("x: float")
53     .Output("y: float")
54     .Doc(R"doc(
55 Test op with no Inference Function registered.
56 x: input
57 y: output
58 )doc");
59 
60 class GraphPropertiesTest : public ::testing::Test {
61  public:
SetUp()62   void SetUp() override {
63     // Provision a single machine with 3 cpu cores
64     cluster_.reset(new SingleMachine(5 * 60, 3, 0));
65     TF_ASSERT_OK(cluster_->Provision());
66 
67     // This function is simply
68     // out = Fill(shape, value), but
69     // Fill requires values in the shape input, not just shape of it, to infer
70     // output shape.
71     auto f = FunctionDefHelper::Create(
72         // Name
73         "MyFillFunc",
74         // Inputs
75         {"shape: int32", "value: float"},
76         // Outputs
77         {"out: float"},
78         // Attrs
79         {},
80         // Nodes
81         {
82             {{"a"},
83              "Fill",
84              {"shape", "value"},
85              {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
86         },
87         // Returns
88         {{"out", "a:output:0"}});
89     function_lib_.add_function()->Swap(&f);
90   }
91 
TearDown()92   void TearDown() override {
93     TF_ASSERT_OK(cluster_->Shutdown());
94     cluster_.reset();
95   }
96 
97  protected:
98   // Returns a string form of <p>, suitable for comparing type and shape.
99   // Example output for 4-d float tensor: "float: [10,2,30,4]"
PropToString(const OpInfo::TensorProperties & p)100   string PropToString(const OpInfo::TensorProperties& p) {
101     string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
102     if (p.shape().unknown_rank()) {
103       strings::StrAppend(&s, "?");
104     } else {
105       strings::StrAppend(&s, "[");
106       for (int i = 0; i < p.shape().dim_size(); ++i) {
107         strings::StrAppend(&s, i == 0 ? "" : ",",
108                            std::max<int64>(p.shape().dim(i).size(), -1));
109       }
110       strings::StrAppend(&s, "]");
111     }
112     return s;
113   }
114 
115   // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
116   // ones.
ExpectTensorValues(const std::vector<int64> & expected,const TensorProto & tensor_proto_to_compare)117   void ExpectTensorValues(const std::vector<int64>& expected,
118                           const TensorProto& tensor_proto_to_compare) {
119     Tensor tensor;
120     ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
121     EXPECT_EQ(expected.size(), tensor.NumElements());
122     // We're interested in only integer tensors as only shapes are exported as
123     // graph properties values.
124     ASSERT_TRUE(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
125     if (tensor.dtype() == DT_INT32) {
126       for (int i = 0; i < tensor.NumElements(); i++) {
127         EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
128       }
129     } else {
130       for (int i = 0; i < tensor.NumElements(); i++) {
131         EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
132       }
133     }
134   }
135 
136   // Compare values of float (DT_FLOAT) tensor against expected
137   // ones.
ExpectFloatTensorValues(const std::vector<float> & expected,const TensorProto & tensor_proto_to_compare)138   void ExpectFloatTensorValues(const std::vector<float>& expected,
139                                const TensorProto& tensor_proto_to_compare) {
140     Tensor tensor;
141     ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
142     EXPECT_EQ(expected.size(), tensor.NumElements());
143     ASSERT_EQ(tensor.dtype(), DT_FLOAT);
144     for (int i = 0; i < tensor.NumElements(); i++) {
145       EXPECT_EQ(expected[i], tensor.flat<float>()(i));
146     }
147   }
148 
149   std::unique_ptr<SingleMachine> cluster_;
150   FunctionDefLibrary function_lib_;
151 };
152 
// Static inference on the trivial test graph. RandomStandardNormal must
// produce a float [10,1] tensor, and every AddN node must output the same type
// and shape it receives.
TEST_F(GraphPropertiesTest, StaticProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  // ASSERT rather than CHECK, so a yielder failure is reported as a test
  // failure instead of aborting the whole test binary.
  ASSERT_TRUE(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The node has one input (the shape of the tensor to generate).
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output. The count is ASSERTed because
      // props[0] is read immediately afterwards.
      const auto props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, props.size());
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype());
      EXPECT_FALSE(prop.shape().unknown_rank());
      ASSERT_EQ(2, prop.shape().dim_size());
      EXPECT_EQ(10, prop.shape().dim(0).size());
      EXPECT_EQ(1, prop.shape().dim(1).size());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      EXPECT_EQ(DT_FLOAT, in_prop.dtype());
      EXPECT_FALSE(in_prop.shape().unknown_rank());
      ASSERT_EQ(2, in_prop.shape().dim_size());
      EXPECT_EQ(10, in_prop.shape().dim(0).size());
      EXPECT_EQ(1, in_prop.shape().dim(1).size());
      // AddN output must mirror its input's type and shape.
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      EXPECT_EQ(in_prop.dtype(), out_props[0].dtype());
      EXPECT_EQ(in_prop.shape().DebugString(),
                out_props[0].shape().DebugString());
    }
  }
}
193 
// ClearOutputProperties / ClearInputProperties must leave the corresponding
// node's property list empty.
TEST_F(GraphPropertiesTest, ClearProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // Check there is something to clear first. Otherwise the emptiness
      // check below would pass trivially.
      EXPECT_FALSE(properties.GetOutputProperties(node.name()).empty());
      properties.ClearOutputProperties(node.name());
      EXPECT_TRUE(properties.GetOutputProperties(node.name()).empty());
    } else if (node.op() == "AddN") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      properties.ClearInputProperties(node.name());
      EXPECT_TRUE(properties.GetInputProperties(node.name()).empty());
    }
  }
}
220 
TEST_F(GraphPropertiesTest,DynamicProperties)221 TEST_F(GraphPropertiesTest, DynamicProperties) {
222   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
223                                           cluster_->GetDeviceNames());
224   GrapplerItem item;
225   CHECK(fake_input.NextItem(&item));
226 
227   GraphProperties properties(item);
228   TF_ASSERT_OK(cluster_->Initialize(item));
229   Status s = properties.InferDynamically(cluster_.get());
230   TF_ASSERT_OK(s);
231 
232   for (const auto& node : item.graph.node()) {
233     if (node.op() == "RandomStandardNormal") {
234       // The random node is missing from the cost graph (why ?)
235       EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
236     } else if (node.op() == "AddN") {
237       // Since the random node is missing, we can't infer the input properties
238       // of the first AddN node. The other AddN nodes have the expected
239       // properties.
240       if (node.name() == "AddN") {
241         const auto props = properties.GetInputProperties(node.name());
242         EXPECT_EQ(1, props.size());
243         const OpInfo::TensorProperties& prop = props[0];
244         EXPECT_EQ(DT_INVALID, prop.dtype());
245         EXPECT_TRUE(prop.shape().unknown_rank());
246       } else {
247         const auto props = properties.GetInputProperties(node.name());
248         EXPECT_EQ(1, props.size());
249         const OpInfo::TensorProperties& prop = props[0];
250         EXPECT_EQ(DT_FLOAT, prop.dtype());
251         EXPECT_FALSE(prop.shape().unknown_rank());
252         EXPECT_EQ(2, prop.shape().dim_size());
253         EXPECT_EQ(10, prop.shape().dim(0).size());
254         EXPECT_EQ(1, prop.shape().dim(1).size());
255         const auto out_props = properties.GetOutputProperties(node.name());
256 #ifdef INTEL_MKL
257         if (!NativeFormatEnabled()) {
258           // Intel MKL AddN OP would have two output.
259           // One is the real output, another one for MKL metadata
260           EXPECT_EQ(2, out_props.size());
261         } else {
262           EXPECT_EQ(1, out_props.size());
263         }
264 #else
265         EXPECT_EQ(1, out_props.size());
266 #endif  // INTEL_MKL
267         string prop_str;
268         ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
269         string out_prop_str;
270         ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
271                                                           &out_prop_str);
272         EXPECT_EQ(prop_str, out_prop_str);
273       }
274     }
275   }
276 }
277 
278 // A test op that outputs different shape based on input_tensor in the shape
279 // inference context.
280 REGISTER_OP("DetectInputValueInShapeInferenceOp")
281     .Input("a: T")
282     .Output("o: T")
283     .Attr("T: {numbertype, bool}")
__anond3d7e4310202(shape_inference::InferenceContext* c) 284     .SetShapeFn([](shape_inference::InferenceContext* c) {
285       if (c->input_tensor(0)) {
286         // 10x10 if input_tensor is given to the inference context.
287         c->set_output(0, c->Matrix(10, 10));
288         return Status::OK();
289       }
290       // unknown rank if input_tensor is not provided.
291       return shape_inference::UnknownShape(c);
292     });
293 
294 // Helper class for testing Const tensor skip.
295 class ConstTensorSkipTestCase {
296  public:
ConstTensorSkipTestCase(const DataType data_type,const std::vector<int64> shape,const double value,const bool expected)297   ConstTensorSkipTestCase(const DataType data_type,
298                           const std::vector<int64> shape, const double value,
299                           const bool expected)
300       : data_type_(data_type),
301         shape_(shape),
302         value_(value),
303         expected_(expected) {}
304 
RunTestAndValidate() const305   void RunTestAndValidate() const {
306     LOG(INFO) << "Run Const tensor skip test: "
307               << "data_type: " << data_type_ << ", shape: {"
308               << absl::StrJoin(shape_, ",") << "}, value: " << value_
309               << ", expected: " << expected_;
310     // Build a graph wiht Const --> Identity --> Detect.
311     GrapplerItem item;
312     const gtl::ArraySlice<int64> shape_array_slice(shape_);
313     Tensor const_tensor_value(data_type_, TensorShape(shape_array_slice));
314     // Fille the const tensor value based on data type.
315     switch (data_type_) {
316       case DT_INT32:
317         test::FillIota<int32>(&const_tensor_value, static_cast<int32>(value_));
318         break;
319       case DT_INT64:
320         test::FillIota<int64>(&const_tensor_value, static_cast<int64>(value_));
321         break;
322       case DT_FLOAT:
323         test::FillIota<float>(&const_tensor_value, static_cast<float>(value_));
324         break;
325       case DT_DOUBLE:
326         test::FillIota<double>(&const_tensor_value,
327                                static_cast<double>(value_));
328         break;
329       case DT_BFLOAT16:
330         test::FillIota<Eigen::bfloat16>(&const_tensor_value,
331                                         static_cast<Eigen::bfloat16>(value_));
332         break;
333       default:
334         CHECK(false) << "Unsupported data type (" << data_type_
335                      << ") in this test.";
336         break;
337     }
338     TF_ASSERT_OK(NodeDefBuilder("const", "Const")
339                      .Attr("dtype", data_type_)
340                      .Attr("value", const_tensor_value)
341                      .Finalize(item.graph.add_node()));
342     TF_ASSERT_OK(NodeDefBuilder("const_identity", "Identity")
343                      .Attr("dtype", data_type_)
344                      .Input("const", 0, data_type_)
345                      .Finalize(item.graph.add_node()));
346     TF_ASSERT_OK(NodeDefBuilder("detect", "DetectInputValueInShapeInferenceOp")
347                      .Attr("T", data_type_)
348                      .Input("const_identity", 0, data_type_)
349                      .Finalize(item.graph.add_node()));
350     item.fetch.push_back("const");
351     item.fetch.push_back("const_identity");
352     item.fetch.push_back("detect");
353 
354     // Run static shape inference.
355     GraphProperties graph_properties(item);
356     TF_ASSERT_OK(graph_properties.InferStatically(false));
357 
358     // Extract input / output properties of interest.
359     const auto& const_output = graph_properties.GetOutputProperties("const");
360     EXPECT_EQ(1, const_output.size());
361     const OpInfo::TensorProperties& const_output0 = const_output[0];
362     const auto& const_identity_input =
363         graph_properties.GetInputProperties("const_identity");
364     EXPECT_EQ(1, const_identity_input.size());
365     const OpInfo::TensorProperties& const_identity_input0 =
366         const_identity_input[0];
367     const auto& const_identity_output =
368         graph_properties.GetOutputProperties("const_identity");
369     EXPECT_EQ(1, const_identity_output.size());
370     const OpInfo::TensorProperties& const_identity_output0 =
371         const_identity_output[0];
372     EXPECT_TRUE(const_output0.has_value());
373     EXPECT_TRUE(const_identity_input0.has_value());
374     EXPECT_TRUE(const_identity_output0.has_value());
375     const auto& detect_input = graph_properties.GetInputProperties("detect");
376     EXPECT_EQ(1, detect_input.size());
377     const OpInfo::TensorProperties& detect_input0 = detect_input[0];
378     const auto& detect_output = graph_properties.GetOutputProperties("detect");
379     EXPECT_EQ(1, detect_output.size());
380     const OpInfo::TensorProperties& detect_output0 = detect_output[0];
381 
382     // Tensor protos are propagated, regardless of types and sizes.
383     EXPECT_TRUE(const_output0.has_value());
384     EXPECT_TRUE(const_identity_input0.has_value());
385     EXPECT_TRUE(const_identity_output0.has_value());
386     EXPECT_TRUE(detect_input0.has_value());
387 
388     // Detect op outputs 10x10 matrix if it has input_tensor in the shape
389     // inference context. Otherwise, unknown rank.
390     if (expected_) {
391       EXPECT_EQ(detect_output0.shape().dim_size(), 2);
392       EXPECT_EQ(detect_output0.shape().dim(0).size(), 10);
393       EXPECT_EQ(detect_output0.shape().dim(1).size(), 10);
394     } else {
395       EXPECT_TRUE(detect_output0.shape().unknown_rank());
396     }
397   }
398 
399  private:
400   DataType data_type_;
401   std::vector<int64> shape_;
402   double value_;
403   bool expected_;
404 };
405 
// We skip const tensor value propagation in shape inference if a const tensor
// is too large. The cases below place the cutoff between 128 elements (still
// propagated) and 129 elements (skipped).
TEST_F(GraphPropertiesTest, SkipInstantiatingConstTensor) {
  const std::vector<ConstTensorSkipTestCase> test_cases = {
      // data_type, shape, value, bool: propagate const?
      {DT_INT32, {16, 8}, 1, true},      // 128 elements; at threshold
      {DT_INT32, {1, 129}, 2, false},    // 129 elements; larger than threshold
      {DT_INT64, {8, 8}, 3, true},       // 64 elements; smaller than threshold
      {DT_INT64, {128, 2}, 0, false},    // 256 elements; larger than threshold
      {DT_FLOAT, {16, 8}, 1.0, true},    // integer value for float tensor
      {DT_FLOAT, {16, 8}, 1.3, true},    // fractional value (1.3)
      {DT_FLOAT, {1, 129}, 0.7, false},  // fractional value (0.7); too large
      {DT_DOUBLE, {16, 8}, 1.0, true},   // integer value for double tensor
      {DT_DOUBLE, {16, 8}, 1.3, true},   // fractional value (1.3)
      {DT_DOUBLE, {1, 129}, 0.7, false},    // fractional value (0.7); too large
      {DT_BFLOAT16, {16, 8}, 1.0, true},    // integer value for bfloat16 tensor
      {DT_BFLOAT16, {16, 8}, 1.3, true},    // fractional value (1.3)
      {DT_BFLOAT16, {1, 129}, 0.7, false},  // fractional value (0.7); too large
  };
  for (const auto& test_case : test_cases) {
    test_case.RunTestAndValidate();
  }
}
429 
// A ref-typed Variable of shape [3,7] must be reported as float_ref [3,7] by
// both static and dynamic inference.
TEST_F(GraphPropertiesTest, Variables) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "Variable")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  item.fetch.push_back("Var");

  Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
  test::FillIota<float>(&initial_val, 0);
  TF_ASSERT_OK(NodeDefBuilder("InitialVal", "Const")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", initial_val)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("InitVar", "Assign")
                   .Input("Var", 0, DT_FLOAT_REF)
                   .Input("InitialVal", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  item.init_ops.push_back("InitVar");

  // Checks that `props` holds exactly one float_ref [3,7] tensor. ASSERTs
  // inside the lambda only return from the lambda, so the dynamic check
  // still runs if the static one fails.
  auto check_var_props = [](const auto& props) {
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    ASSERT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  };

  {
    SCOPED_TRACE("static inference");
    GraphProperties static_properties(item);
    TF_ASSERT_OK(static_properties.InferStatically(false));
    check_var_props(static_properties.GetOutputProperties("Var"));
  }
  {
    SCOPED_TRACE("dynamic inference");
    TF_ASSERT_OK(cluster_->Initialize(item));
    GraphProperties dynamic_properties(item);
    TF_ASSERT_OK(dynamic_properties.InferDynamically(cluster_.get()));
    check_var_props(dynamic_properties.GetOutputProperties("Var"));
  }
}
478 
// A ReadVariableOp reading a resource handle that passed through an Enter node
// must still recover the variable's float [3,7] shape.
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("Enter", "Enter")
                   .Attr("T", DT_RESOURCE)
                   .Attr("frame_name", "while_context")
                   .Attr("is_constant", true)
                   .Attr("parallel_iterations", 10)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Enter", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
  // ASSERT: props[0] is read next.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
508 
// Reading a VarHandleOp resource directly must yield the variable's declared
// float [3,7] shape.
TEST_F(GraphPropertiesTest, VarHandles) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));

  TF_ASSERT_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto props = properties.GetOutputProperties("VarRead");
  // ASSERT: props[0] is read next.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
533 
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
  // Test graph is first generated in python using:
  /*
    i0 = tf.constant(0)
    v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
    def cond(i, x):
      return i < 3
    def body(i, x):
      return i + 1, x + x
    v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
  */
  // and then modified by hand such that the ReadVariableOp is inside the loop
  // body instead of outside the while loop (which is the case when constructed
  // using the python API), such that we have the following pattern: VarHandleOp
  // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
  // DT_RESOURCE is passed all the way until ReadVariableOp.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop_var_handle_op.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> resource_nodes{
      "loop_var",       "while/Enter",         "while/Merge", "while/Switch",
      "while/Identity", "while/NextIteration", "while/Exit"};
  for (const string& node : resource_nodes) {
    SCOPED_TRACE(node);
    const auto props = properties.GetOutputProperties(node);
    // At least one output (Merge has 2). ASSERT because props[0] is read next.
    ASSERT_FALSE(props.empty());
    EXPECT_EQ("resource: []", PropToString(props[0]));
  }

  // After ReadVariableOp, the shape should be recovered.
  const auto props = properties.GetOutputProperties("while/ReadVariableOp");
  ASSERT_EQ(1, props.size());
  EXPECT_EQ("int32: []", PropToString(props[0]));
}
571 
// A FIFO queue that is only dequeued from (never enqueued) and declares no
// `shapes` attr: the dequeued float tensor has unknown rank.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT});
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem grappler_item;
  TF_ASSERT_OK(scope.ToGraphDef(&grappler_item.graph));

  GraphProperties graph_props(grappler_item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_props = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_props.size());
  EXPECT_EQ("float: ?", PropToString(dequeue_props[0]));
}
588 
// A dequeue-only FIFO queue whose `shapes` attr is fully defined: the dequeued
// tensor takes that shape, [3,7,1].
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto queue_attrs = ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}});
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                              queue_attrs);
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem grappler_item;
  TF_ASSERT_OK(scope.ToGraphDef(&grappler_item.graph));

  GraphProperties graph_props(grappler_item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_props = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_props.size());
  EXPECT_EQ("float: [3,7,1]", PropToString(dequeue_props[0]));
}
606 
// A dequeue-only FIFO queue whose `shapes` attr has an unknown (-1) dimension:
// the dequeued tensor keeps that partial shape, [3,7,-1].
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto queue_attrs = ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}});
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                              queue_attrs);
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem grappler_item;
  TF_ASSERT_OK(scope.ToGraphDef(&grappler_item.graph));

  GraphProperties graph_props(grappler_item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_props = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_props.size());
  EXPECT_EQ("float: [3,7,-1]", PropToString(dequeue_props[0]));
}
624 
TEST_F(GraphPropertiesTest,Queues)625 TEST_F(GraphPropertiesTest, Queues) {
626   // Create a graph with known input shapes, and propagate the shapes through a
627   // couple of queues.
628   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
629 
630   auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
631   Output rnd =
632       ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
633   Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
634   auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
635   auto dequeue1 =
636       ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
637 
638   auto q2 =
639       ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
640   Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
641   auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
642   auto dequeue2 =
643       ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
644 
645   auto q4 =
646       ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
647   auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
648   auto enqueue4_2 =
649       ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
650   auto dequeue4 =
651       ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
652 
653   // Create a queue that takes in three tensors.
654   auto q5 = ops::RandomShuffleQueue(
655       root.WithOpName("Queue5"),
656       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
657   Output rnd2 =
658       ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
659   Output rnd3 =
660       ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
661   auto enqueue5 =
662       ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
663   auto dequeue5 = ops::QueueDequeue(
664       root.WithOpName("Dequeue5"), q5,
665       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
666 
667   GrapplerItem item;
668   TF_ASSERT_OK(root.ToGraphDef(&item.graph));
669 
670   GraphProperties properties(item);
671   TF_ASSERT_OK(properties.InferStatically(false));
672 
673   const auto props1 = properties.GetOutputProperties("Dequeue1");
674   ASSERT_EQ(1, props1.size());
675   EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
676 
677   const auto props2 = properties.GetOutputProperties("Dequeue2");
678   ASSERT_EQ(1, props2.size());
679   EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
680 
681   // The dequeue3 op shape is unknown. The square2 op shape is known. Verify
682   // that we merge the 2 properly to determine the shape of the data coming out
683   // of the queue.
684   const auto props4 = properties.GetOutputProperties("Dequeue4");
685   ASSERT_EQ(1, props4.size());
686   EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
687 
688   // The dequeue5 op shape is known.
689   const auto props5 = properties.GetOutputProperties("Dequeue5");
690   ASSERT_EQ(3, props5.size());
691   EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
692   EXPECT_EQ("double: [10]", PropToString(props5[1]));
693   EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
694 }
695 
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      x = tf.constant(2)
      y = tf.constant(5)
      z = tf.ones([1,1,1])
      def f1(): return tf.concat([z, z], axis=0)
      def f2(): return tf.concat([z, z], axis=1)
      r = tf.cond(tf.less(x, y), f1, f2)
      tf.concat([r, r], axis=2)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "merge_without_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // The Merge of the two cond branches ([2,1,1] and [1,2,1]) keeps only the
  // dimensions on which both branches agree.
  const std::vector<string> nodes{"cond/Merge", "cond/concat",
                                  "cond/concat_1"};
  const std::vector<string> expected_outputs{
      "float: [-1,-1,1]", "float: [2,1,1]", "float: [1,2,1]"};
  ASSERT_EQ(nodes.size(), expected_outputs.size());
  for (size_t i = 0; i < nodes.size(); ++i) {
    SCOPED_TRACE(nodes[i]);
    const auto props = properties.GetOutputProperties(nodes[i]);
    // ASSERT: props[0] is read next.
    ASSERT_FALSE(props.empty());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(expected_outputs[i], PropToString(prop));
  }

  // The "Less" node should be fed by 2 int32 scalar constant values.
  const auto props = properties.GetInputProperties("Less");
  EXPECT_EQ(2, props.size());
  for (const auto& prop : props) {
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_TRUE(prop.has_value());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
}
737 
TEST_F(GraphPropertiesTest, WhileLoop) {
  // Test graph produced in python using:
  /*
     with tf.Graph().as_default():
       i0 = tf.constant(0)
       m0 = tf.placeholder([-1, 2])
       c = lambda i, m: i < 10
       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
       r = tf.while_loop(
              c, b, loop_vars=[i0, m0],
              shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
       with open('/tmp/graph.pbtxt', 'w') as f:
         f.write(str(tf.get_default_graph().as_graph_def()))
  */
  // NOTE(review): as written, `tf.placeholder([-1, 2])` would not run
  // (tf.placeholder takes a dtype as its first argument), and the assertions
  // below read a node named "ones" that this snippet does not create. The
  // checked-in while_loop.pbtxt is the source of truth; confirm the snippet
  // against it.

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  for (const string& node : nodes) {
    const auto& props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,2]", PropToString(prop)) << node;
  }

  // The loop output's batch dim should be different from the input batch dim
  // since the body concatenates along the batch dim. Both dims are unknown;
  // the assertions expect each to be a value <= -2 and the two to differ.
  const auto& in_props = properties.GetOutputProperties("ones");
  const auto& out_props = properties.GetOutputProperties("while/Exit_1");
  ASSERT_FALSE(in_props.empty());
  ASSERT_FALSE(out_props.empty());
  const auto& shape_in = in_props[0].shape();
  const auto& shape_out = out_props[0].shape();
  // Guard dim(0): reading a dim past dim_size() is out of range.
  ASSERT_GE(shape_in.dim_size(), 1);
  ASSERT_GE(shape_out.dim_size(), 1);
  EXPECT_GE(-2, shape_in.dim(0).size());
  EXPECT_GE(-2, shape_out.dim(0).size());
  EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
}
777 
TEST_F(GraphPropertiesTest, NestedLoop) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=2)

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                              tf.TensorShape([None, 1, None])])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        j, y = inner(0, x)
        return i+1, tf.concat([x, x], axis=0)

      r = tf.while_loop(outer_cond, outer_body,
                        loop_vars=[i0, tf.ones([1, 1, 1])],
                        shape_invariants=[i0.get_shape(),
                                          tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "nested_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Checks that output 0 of every node in `nodes` is a float tensor whose
  // string form is `expected`.
  auto expect_float_outputs = [&](const std::vector<string>& nodes,
                                  const string& expected) {
    for (const string& node : nodes) {
      const auto& props = properties.GetOutputProperties(node);
      ASSERT_FALSE(props.empty()) << node;
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
      EXPECT_EQ(expected, PropToString(prop)) << node;
    }
  };

  // Outer loop: the body concatenates along axis 0.
  expect_float_outputs(
      {"while/Merge_1", "while/NextIteration_1", "while/Exit_1"},
      "float: [-1,1,1]");
  // Inner loop: the body concatenates along axis 2.
  expect_float_outputs({"while/while/Merge_1", "while/while/NextIteration_1",
                        "while/while/Exit_1"},
                       "float: [-1,1,-1]");
}
836 
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q = tf.FIFOQueue(1, "float")

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=0)

        return tf.while_loop(inner_cond, inner_body,
                             loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape(None)])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        q.enqueue(x)
        y = tf.concat([x, x], axis=2)
        inner(0, q.dequeue())
        return i+1, y

      i, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[i0, tf.ones([1, 1, 1])],
                           shape_invariants=[i0.get_shape(),
                                             tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_queues.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Checks that output 0 of every node in `nodes` is a float tensor whose
  // string form is `expected`.
  auto expect_float_outputs = [&](const std::vector<string>& nodes,
                                  const string& expected) {
    for (const string& node : nodes) {
      const auto& props = properties.GetOutputProperties(node);
      ASSERT_FALSE(props.empty()) << node;
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
      EXPECT_EQ(expected, PropToString(prop)) << node;
    }
  };

  // Outer loop: the body concatenates along axis 2.
  expect_float_outputs(
      {"while/Merge_1", "while/NextIteration_1", "while/Exit_1"},
      "float: [1,1,-1]");
  // Inner loop: fed from the queue, the body concatenates along axis 0.
  expect_float_outputs({"while/while/Merge_1", "while/while/NextIteration_1",
                        "while/while/Exit_1"},
                       "float: [-1,1,-1]");
}
899 
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
        v = tf.get_variable(initializer=i0, name='loop_var')

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j + 1, y + y

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        y = x + x
        inner(0, v)
        return i + 1, y

      v, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[v, tf.constant(1)])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_resource_vars.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Checks that output 0 of every node in `nodes` is an int32 scalar.
  auto expect_int32_scalar_outputs = [&](const std::vector<string>& nodes) {
    for (const string& node : nodes) {
      const auto& props = properties.GetOutputProperties(node);
      ASSERT_FALSE(props.empty()) << node;
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_INT32, prop.dtype()) << node;
      EXPECT_EQ("int32: []", PropToString(prop)) << node;
    }
  };

  // Outer loop.
  expect_int32_scalar_outputs(
      {"while/Merge_1", "while/NextIteration_1", "while/Exit_1"});
  // Inner loop.
  expect_int32_scalar_outputs({"while/while/Merge_1",
                               "while/while/NextIteration_1",
                               "while/while/Exit_1"});
}
957 
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q0 = tf.FIFOQueue(1, "float")
      q0.enqueue(tf.ones([2, 2]))
      q1 = tf.FIFOQueue(1, "float")

      def c(i, m):
        return i < 10

      def b(i, m):
        return i+1, tf.concat([m, m], axis=0)

      i, m = tf.while_loop(
          c, b, loop_vars=[i0, q0.dequeue()],
          shape_invariants=[i0.get_shape(), tf.TensorShape(None)])

      q1.enqueue(m)
      v = q1.dequeue()
      tf.concat([v, v], axis=1)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "queues_and_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // The loop concatenates along axis 0, so only the batch dim is unknown.
  const std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  for (const string& node : nodes) {
    const auto& props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,2]", PropToString(prop)) << node;
  }

  // The loop result goes through q1 and is concatenated along axis 1.
  const auto& concat_props = properties.GetOutputProperties("concat");
  ASSERT_FALSE(concat_props.empty());
  const OpInfo::TensorProperties& concat_prop = concat_props[0];
  EXPECT_EQ(DT_FLOAT, concat_prop.dtype());
  EXPECT_EQ("float: [-1,4]", PropToString(concat_prop));
}
1006 
TEST_F(GraphPropertiesTest,InferRestoreOpShape)1007 TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
1008   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1009   Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
1010                              DataType::DT_FLOAT);
1011   Output filename =
1012       ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
1013   Output tensor_name =
1014       ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
1015   Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
1016                                 DataType::DT_FLOAT);
1017   Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);
1018 
1019   Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
1020                                       string("256 256 0,128:-"), TensorShape());
1021   Output restore_slice =
1022       ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
1023                         shape_and_slice, DataType::DT_FLOAT);
1024   Output init_restore_slice =
1025       ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);
1026 
1027   Output restore_v2 =
1028       ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
1029                         shape_and_slice, DataType::DT_FLOAT);
1030   Output init_restore_v2 =
1031       ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);
1032 
1033   GrapplerItem item;
1034   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1035   item.fetch.push_back("init_restore");
1036 
1037   GraphProperties properties(item);
1038   TF_ASSERT_OK(properties.InferStatically(false));
1039 
1040   const auto restore_props = properties.GetOutputProperties("restore");
1041   const OpInfo::TensorProperties& restore_prop = restore_props[0];
1042   EXPECT_EQ(DT_FLOAT, restore_prop.dtype());
1043   EXPECT_EQ("float: [128,256]", PropToString(restore_prop));
1044 
1045   const auto restore_slice_props =
1046       properties.GetOutputProperties("restore_slice");
1047   const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0];
1048   EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype());
1049   EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop));
1050 
1051   const auto restorev2_props = properties.GetOutputProperties("restore_v2");
1052   const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0];
1053   EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype());
1054   EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop));
1055 
1056   // Check input shapes of assign op are propagated correctly.
1057   const auto input_props = properties.GetInputProperties("init_restore");
1058   ASSERT_EQ(2, input_props.size());
1059   const OpInfo::TensorProperties& input_prop = input_props[1];
1060   EXPECT_EQ(DT_FLOAT, input_prop.dtype());
1061   EXPECT_EQ("float: [128,256]", PropToString(input_prop));
1062 }
1063 
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // `var` is declared with an unknown shape; `var2` is the only node with a
  // fully defined [128,256] shape, so it is where restore's expected shape
  // has to come from.
  Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
                             DataType::DT_FLOAT);
  Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}),
                              DataType::DT_FLOAT);
  Output filename =
      ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
  Output tensor_name =
      ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
  Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
                                DataType::DT_FLOAT);
  // Both Assign ops consume the same restore output.
  Output init = ops::Assign(s.WithOpName("init"), var, restore);
  Output init2 = ops::Assign(s.WithOpName("init2"), var2, restore);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("init");
  item.fetch.push_back("init2");

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& props = properties.GetOutputProperties("restore");
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(prop));
}
1092 
TEST_F(GraphPropertiesTest,TensorAsShapesPropagation)1093 TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
1094   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1095   Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
1096   Output a1 = ops::Identity(s.WithOpName("a1"), a);
1097   Output b = ops::Const(s.WithOpName("b"), 99, {});
1098   Output b1 = ops::Identity(s.WithOpName("b1"), b);
1099   Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
1100   Output c1 = ops::Identity(s.WithOpName("c1"), c);
1101 
1102   GrapplerItem item;
1103   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1104   GraphProperties properties(item);
1105   TF_ASSERT_OK(properties.InferStatically(false));
1106 
1107   // Check output shapes.
1108   EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
1109   EXPECT_EQ("int32: [2]",
1110             PropToString(properties.GetOutputProperties("a1")[0]));
1111   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
1112   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
1113   EXPECT_EQ("int32: [4,4,4]",
1114             PropToString(properties.GetOutputProperties("c")[0]));
1115   EXPECT_EQ("int32: [4,4,4]",
1116             PropToString(properties.GetOutputProperties("c1")[0]));
1117 
1118   // Check has_value.
1119   EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
1120   EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
1121   EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
1122   EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
1123   EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
1124   EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
1125   EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
1126   EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
1127   // Note that we propagate tensor value of only 1D vector and scalar.
1128   EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());
1129 
1130   // Check values.
1131   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
1132   ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
1133   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
1134   ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
1135   ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
1136   ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
1137   std::vector<int64> c_values;
1138   for (int i = 0; i < 4 * 4 * 4; i++) {
1139     c_values.push_back(1);
1140   }
1141   ExpectTensorValues({c_values},
1142                      properties.GetOutputProperties("c")[0].value());
1143   ExpectTensorValues({c_values},
1144                      properties.GetInputProperties("c1")[0].value());
1145   ExpectTensorValues({c_values},
1146                      properties.GetOutputProperties("c1")[0].value());
1147 }
1148 
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // a is the rank-1 tensor {5, 5}.
  Output a = ops::Const(s.WithOpName("a"), 5, {2});
  Output b = ops::Identity(s.WithOpName("b"), a);
  Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only b's shape but also b's value to figure out its output
  // shape; hence, the Identity op (b) should pass a's value along as
  // output_tensors_as_shape.
  Output d = ops::Fill(s.WithOpName("fill"), b, c);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
}
1167 
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
  // When using aggressive_shape_inference, we run EvaluateNode() for
  // allowlisted ops and small input / output tensors. For instance, Fill op is
  // evaluated and produces output tensor value if output tensor size is small
  // (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
  // This is to avoid wasting time and memory for producing huge tensors (e.g.,
  // initializing a large table using Fill).
  // Note that we do not propagate float const tensors with fractional values
  // (even if they're small); so this test should use integer values.
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 4, {2});  // 4x4
    Output b = ops::Const(s.WithOpName("const"), 7, {});
    // Shape described by a is small; expect output values of Fill op.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("int32: [4,4]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());
  }
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 1000, {4});  // 1000x1000x1000x1000
    Output b = ops::Const(s.WithOpName("const"), 7, {});
    // Shape described by a is huge; in that case we skip value inference.
    // Otherwise, it'd be too much overhead.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("int32: [1000,1000,1000,1000]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }
}
1217 
TEST_F(GraphPropertiesTest, PackWithConstInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {});
  Output b = ops::Const(s.WithOpName("b"), 2, {});
  Output c = ops::Const(s.WithOpName("c"), 3, {});
  Output d = ops::Const(s.WithOpName("d"), 4, {});
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
}
1240 
TEST_F(GraphPropertiesTest, RankOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // A rank-3 constant, so Rank should yield the scalar value 3.
  Output c = ops::Const(s.WithOpName("Const"), 1, {4, 4, 4});
  Output r = ops::Rank(s.WithOpName("Rank"), c);
  // The value should also survive an Identity.
  Output i = ops::Identity(s.WithOpName("Identity"), r);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& rank_props = properties.GetOutputProperties("Rank");
  ASSERT_EQ(1, rank_props.size());
  const OpInfo::TensorProperties& rank_prop = rank_props[0];
  EXPECT_EQ("int32: []", PropToString(rank_prop));
  EXPECT_TRUE(rank_prop.has_value());
  ExpectTensorValues({3}, rank_prop.value());

  const auto& identity_props = properties.GetOutputProperties("Identity");
  ASSERT_EQ(1, identity_props.size());
  const OpInfo::TensorProperties& identity_prop = identity_props[0];
  EXPECT_EQ("int32: []", PropToString(identity_prop));
  EXPECT_TRUE(identity_prop.has_value());
  ExpectTensorValues({3}, identity_prop.value());
}
1262 
TEST_F(GraphPropertiesTest, SizeOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // A 1x2x3x4 constant, so Size should yield the scalar value 24.
  Output c = ops::Const(s.WithOpName("Const"), 1, {1, 2, 3, 4});
  Output r = ops::Size(s.WithOpName("Size"), c);
  // The value should also survive an Identity.
  Output i = ops::Identity(s.WithOpName("Identity"), r);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& size_props = properties.GetOutputProperties("Size");
  ASSERT_EQ(1, size_props.size());
  const OpInfo::TensorProperties& size_prop = size_props[0];
  EXPECT_EQ("int32: []", PropToString(size_prop));
  EXPECT_TRUE(size_prop.has_value());
  ExpectTensorValues({24}, size_prop.value());

  const auto& identity_props = properties.GetOutputProperties("Identity");
  ASSERT_EQ(1, identity_props.size());
  const OpInfo::TensorProperties& identity_prop = identity_props[0];
  EXPECT_EQ("int32: []", PropToString(identity_prop));
  EXPECT_TRUE(identity_prop.has_value());
  ExpectTensorValues({24}, identity_prop.value());
}
1284 
TEST_F(GraphPropertiesTest,PackWithConstMinus1AndReshapes)1285 TEST_F(GraphPropertiesTest, PackWithConstMinus1AndReshapes) {
1286   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1287   Output shape0 = ops::Const(s.WithOpName("shape0"), 4, {});
1288   Output shape1 = ops::Const(s.WithOpName("shape1"), -1, {});
1289   Output pack = ops::Stack(s.WithOpName("pack"), {shape0, shape1});
1290   // pack is [2], with values {4, -1}.
1291 
1292   Output x0_ = ops::Placeholder(s.WithOpName("x0_"), DataType::DT_FLOAT);
1293   Output x1_ = ops::Placeholder(s.WithOpName("x1_"), DataType::DT_FLOAT);
1294 
1295   Output x0 = ops::Reshape(s.WithOpName("x0"), x0_, pack);
1296   Output x1 = ops::Reshape(s.WithOpName("x1"), x1_, pack);
1297   // Two unknown rank tensors (x0_ and x1_) are reshaped with pack {4, -1},
1298   // their output shapes would be [4, -1]. However, though we use the same
1299   // shape input to the Reshape ops, their output shapes can be different;
1300   // i.e., unknown dim values (-1) of x0 and x1 shapes are not necessarily
1301   // the same.
1302 
1303   // if input to the Select ops. Note that s0 has a fully defined shape, while
1304   // s1 has unknown shape.
1305   Output s0 = ops::Const(s.WithOpName("s0"), true, {4, 16});
1306   Output s1 = ops::Placeholder(s.WithOpName("s1"), DataType::DT_BOOL);
1307 
1308   Output y0 = ops::Placeholder(s.WithOpName("y0"), DataType::DT_FLOAT);
1309   Output y1 = ops::Placeholder(s.WithOpName("y1"), DataType::DT_FLOAT);
1310 
1311   // We instantiate SelectV2, but will replace it with Select. The shape
1312   // inference function for Select links all inputs and outputs as they should
1313   // have the same shapes.
1314   Output z0 = ops::SelectV2(s.WithOpName("z0"), s0, x0, y0);
1315   Output z1 = ops::SelectV2(s.WithOpName("z1"), s1, x1, y1);
1316 
1317   // For z0, as we know the shape of s0, symbolic shape manager in shape
1318   // inference will make the shapes of x0, y0, and z0 equal to the shape of s0,
1319   // which is [4, 16].
1320   // For z1, s0 and y1 are all unknown shapes, so we can infer they're [4, -1]
1321   // at best.
1322   // Note that x0 and x1 share the same shape input to the Reshape op, but
1323   // -1 in the shape input should not be treated as the same symoblic unknown
1324   // dim; it is merely a constant value -1 for identitying unknown dim for
1325   // Reshape operation.
1326 
1327   GrapplerItem item;
1328   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1329 
1330   // Replace SelectV2 op with Select op.
1331   for (int i = 0; i < item.graph.node_size(); ++i) {
1332     auto* node = item.graph.mutable_node(i);
1333     if (node->op() == "SelectV2") {
1334       node->set_op("Select");
1335     }
1336   }
1337 
1338   GraphProperties properties(item);
1339   TF_ASSERT_OK(properties.InferStatically(false));
1340   for (const auto& node_name : {"x0", "y0", "z0"}) {
1341     const auto out_props = properties.GetOutputProperties(node_name);
1342     const OpInfo::TensorProperties out_prop0 = out_props[0];
1343     EXPECT_EQ("float: [4,16]", PropToString(out_prop0));
1344   }
1345   {
1346     const auto out_props = properties.GetOutputProperties("s0");
1347     const OpInfo::TensorProperties out_prop0 = out_props[0];
1348     EXPECT_EQ("bool: [4,16]", PropToString(out_prop0));
1349   }
1350 
1351   for (const auto& node_name : {"x1", "y1", "z1"}) {
1352     const auto out_props = properties.GetOutputProperties(node_name);
1353     const OpInfo::TensorProperties out_prop0 = out_props[0];
1354     EXPECT_EQ("float: [4,-1]", PropToString(out_prop0));
1355   }
1356   // if input of Select can be either vector or the same shape to the
1357   // input/output; in this case, even if we know input and output are
1358   // [4, ?], we can't say it's [4, ?] or a vector; hence, it shoudl be
1359   // unknown.
1360   {
1361     const auto out_props = properties.GetOutputProperties("s1");
1362     const OpInfo::TensorProperties out_prop0 = out_props[0];
1363     EXPECT_EQ("bool: ?", PropToString(out_prop0));
1364   }
1365 }
1366 
TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
  // ops fed from Const ops.
  // If output_tensors_as_shape is not set for those Identity ops or the Pack
  // op doesn't take input_tensors_as_shape, Fill op's input doesn't have a
  // value; hence, its output shape becomes unknown.
  Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
  Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
  Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
  Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
  Output a = ops::Identity(s.WithOpName("a"), a0);
  Output b = ops::Identity(s.WithOpName("b"), b0);
  Output c = ops::Identity(s.WithOpName("c"), c0);
  Output d = ops::Identity(s.WithOpName("d"), d0);
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
}
1398 
TEST_F(GraphPropertiesTest,FunctionWithDtResourceInput)1399 TEST_F(GraphPropertiesTest, FunctionWithDtResourceInput) {
1400   // Function ops may have DT_RESOURCE input; if not properly set shapes and
1401   // dtypes through the DT_RESOURCE _Arg, we cannot infer output shapes of such
1402   // function ops.
1403   GrapplerItem item;
1404   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1405                                  "function_with_dt_resource_input.pbtxt");
1406   TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1407 
1408   // This graph evaluates FunctionWithDtResourceInput with two inputs:
1409   // x [DT_FLOAT Const],
1410   // _Arg [DT_RESOURCE _Arg]
1411   // and has two outputs:
1412   // z1 = x + _Arg
1413   // z2 = x
1414   {
1415     GraphProperties properties(item);
1416     TF_ASSERT_OK(properties.InferStatically(false));
1417     const auto out_props =
1418         properties.GetOutputProperties("FunctionWithDtResourceInput");
1419     EXPECT_EQ(out_props.size(), 2);
1420     const OpInfo::TensorProperties out_prop0 = out_props[0];
1421     EXPECT_EQ("float: [1,3]", PropToString(out_prop0));
1422     const OpInfo::TensorProperties out_prop1 = out_props[1];
1423     EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1424   }
1425 
1426   {
1427     // Delete _handle_dtypes and _handle_shapes attr for the input _Arg node.
1428     for (int i = 0; i < item.graph.node_size(); i++) {
1429       auto* node = item.graph.mutable_node(i);
1430       if (node->name() == "y") {  // _Arg node with DT_RESOURCE
1431         node->mutable_attr()->erase("_handle_dtypes");
1432         node->mutable_attr()->erase("_handle_shapes");
1433         break;
1434       }
1435     }
1436     // We cannot infer the function output shape correctly without those attr,
1437     // but still it shouldn't fail; also, there can be some shapes we can
1438     // infer in such a case. In this test graph,
1439     // z2 of the function node just returns x input; hence, even if _Arg's shape
1440     // cannot be inferred, we can infer z2 output shape.
1441     GraphProperties properties(item);
1442     TF_ASSERT_OK(properties.InferStatically(false));
1443     const auto out_props =
1444         properties.GetOutputProperties("FunctionWithDtResourceInput");
1445     EXPECT_EQ(out_props.size(), 2);
1446     const OpInfo::TensorProperties out_prop0 = out_props[0];
1447     // Without shape and dtype attr, we don't know _Arg's shape; hence, unknown
1448     // for x + _Arg.
1449     EXPECT_EQ("float: ?", PropToString(out_prop0));
1450     // The 2nd output is just x, so even if _Arg's shape is unknown, we can
1451     // infer this output shape.
1452     const OpInfo::TensorProperties out_prop1 = out_props[1];
1453     EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1454   }
1455 }
1456 
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
  // Calls MyFillFunc (from the fixture's function_lib_) with Const inputs; the
  // expected output shape equals the value of the `shape` input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
  Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                         s.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto _value = tensorflow::ops::AsNodeOut(s, value);
  TF_ASSERT_OK(
      builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyFillFunc");
  // Guard the index below so a missing entry fails cleanly.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
}
1478 
TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
  // Same as FunctionWithConstInput, but the function's shape input is an
  // Identity of a Const, so tensor shapes, not the tensor value, should be
  // used as the Const input to the function.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
  Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
  Output shape = ops::Identity(s.WithOpName("shape"), shape_);
  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                         s.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto _value = tensorflow::ops::AsNodeOut(s, value);
  TF_ASSERT_OK(
      builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyFillFunc");
  // Guard the index below so a missing entry fails cleanly.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
}
1504 
TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32"},                                               // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));

  // MyFunc takes a Const (shape) and passes it through an Identity. Expect the
  // function output to have the same shape as well as the same value
  // (output_tensors_as_shape) as the input Const tensor.
  Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  TF_ASSERT_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  // MyFunc has exactly one output and one input (see the FunctionDef above).
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("int32: [2]", PropToString(out_prop0));
  EXPECT_TRUE(out_prop0.has_value());
  ExpectTensorValues({5, 7}, out_prop0.value());

  const auto& in_props = properties.GetInputProperties("MyFunc");
  ASSERT_EQ(1, in_props.size());
  ExpectTensorValues({5, 7}, in_props[0].value());
}
1540 
TEST_F(GraphPropertiesTest,ArithmeticFunctionReturnTensorValue)1541 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
1542   FunctionDefLibrary library;
1543   // Function that adds two input values.
1544   *library.add_function() = FunctionDefHelper::Create(
1545       "MyFunc",                                                   // Name
1546       {"x: int32", "y: int32"},                                   // Inputs
1547       {"out: int32"},                                             // Outputs
1548       {},                                                         // Attrs
1549       {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}},  // Nodes
1550       {{"out", "a:z:0"}});                                        // Returns
1551   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1552   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1553 
1554   Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
1555   auto _shape = tensorflow::ops::AsNodeOut(s, shape);
1556   auto builder =
1557       tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1558   tensorflow::Node* func_op;
1559   TF_ASSERT_OK(
1560       builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));
1561 
1562   GrapplerItem item;
1563   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1564   {
1565     GraphProperties properties(item);
1566     // Without aggressive_shape_inference, the internal function does not
1567     // evaluate output value.
1568     TF_ASSERT_OK(properties.InferStatically(
1569         /*assume_valid_feeds=*/true,
1570         /*aggressive_shape_inference=*/false,
1571         /*include_tensor_values=*/true));
1572     const auto out_props = properties.GetOutputProperties("MyFunc");
1573     const OpInfo::TensorProperties out_prop0 = out_props[0];
1574     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1575     EXPECT_FALSE(out_prop0.has_value());
1576   }
1577 
1578   {
1579     GraphProperties properties(item);
1580     // With aggressive_shape_inference, output value is evaluated.
1581     TF_ASSERT_OK(properties.InferStatically(
1582         /*assume_valid_feeds=*/true,
1583         /*aggressive_shape_inference=*/true,
1584         /*include_tensor_values=*/true));
1585     const auto out_props = properties.GetOutputProperties("MyFunc");
1586     const OpInfo::TensorProperties out_prop0 = out_props[0];
1587     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1588     EXPECT_TRUE(out_prop0.has_value());
1589 
1590     ExpectTensorValues({10, 14}, out_prop0.value());
1591     ExpectTensorValues({5, 7},
1592                        properties.GetInputProperties("MyFunc")[0].value());
1593     ExpectTensorValues({5, 7},
1594                        properties.GetInputProperties("MyFunc")[1].value());
1595   }
1596 }
1597 
1598 // Same as the above, but float values; also, one of the function input is
1599 // Identity of Const.
TEST_F(GraphPropertiesTest,ArithmeticFunctionReturnTensorValueFloat)1600 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValueFloat) {
1601   FunctionDefLibrary library;
1602   // Function that adds two input values.
1603   *library.add_function() = FunctionDefHelper::Create(
1604       "MyFunc",                                                   // Name
1605       {"x: float", "y: float"},                                   // Inputs
1606       {"out: float"},                                             // Outputs
1607       {},                                                         // Attrs
1608       {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
1609       {{"out", "a:z:0"}});                                        // Returns
1610   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1611   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1612 
1613   Output x1 = ops::Const(s.WithOpName("x1"), {5.0f, 7.0f}, {2});
1614   Output x2 = ops::Identity(s.WithOpName("x1"), x1);
1615   auto _x1 = tensorflow::ops::AsNodeOut(s, x1);
1616   auto _x2 = tensorflow::ops::AsNodeOut(s, x2);
1617   auto builder =
1618       tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1619   tensorflow::Node* func_op;
1620   TF_ASSERT_OK(builder.Input(_x1).Input(_x2).Finalize(s.graph(), &func_op));
1621 
1622   GrapplerItem item;
1623   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1624   {
1625     GraphProperties properties(item);
1626     // Without aggressive_shape_inference, the internal function does not
1627     // evaluate output value.
1628     TF_ASSERT_OK(properties.InferStatically(
1629         /*assume_valid_feeds=*/true,
1630         /*aggressive_shape_inference=*/false,
1631         /*include_tensor_values=*/true));
1632     const auto out_props = properties.GetOutputProperties("MyFunc");
1633     const OpInfo::TensorProperties out_prop0 = out_props[0];
1634     EXPECT_EQ("float: [2]", PropToString(out_prop0));
1635     EXPECT_FALSE(out_prop0.has_value());
1636   }
1637 
1638   {
1639     GraphProperties properties(item);
1640     // With aggressive_shape_inference, output value is evaluated.
1641     TF_ASSERT_OK(properties.InferStatically(
1642         /*assume_valid_feeds=*/true,
1643         /*aggressive_shape_inference=*/true,
1644         /*include_tensor_values=*/true));
1645     const auto out_props = properties.GetOutputProperties("MyFunc");
1646     const OpInfo::TensorProperties out_prop0 = out_props[0];
1647     EXPECT_EQ("float: [2]", PropToString(out_prop0));
1648     EXPECT_TRUE(out_prop0.has_value());
1649 
1650     ExpectFloatTensorValues({10.0, 14.0}, out_prop0.value());
1651     ExpectFloatTensorValues({5.0, 7.0},
1652                             properties.GetInputProperties("MyFunc")[0].value());
1653     ExpectFloatTensorValues({5.0, 7.0},
1654                             properties.GetInputProperties("MyFunc")[1].value());
1655   }
1656 }
1657 
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
  // Builds a graph whose function takes a scalar input, so that function shape
  // inference is fed a scalar Placeholder:
  //   Placeholder -> Identity -> MyFunc,
  // where MyFunc simply returns an Identity of its input; all tensors are
  // scalars.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float"},                                               // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
  Output placeholder =
      ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({})));
  Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
  auto _identity = tensorflow::ops::AsNodeOut(s, identity);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  TF_ASSERT_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // GraphDef producer versions <= 21 infer the output shape of a Placeholder
  // with an empty (scalar) shape attr as unknown instead of as a scalar, so the
  // rest of this test is only meaningful for producer versions > 21.
  ASSERT_GT(item.graph.versions().producer(), 21);

  // MyFunc output shouldn't have unknown rank.
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  // MyFunc has exactly one output (see the FunctionDef above).
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());
}
1697 
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x,y)

    with tf.Graph().as_default():
      x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = MyAdd(x, y)
      z = MyAdd(x, z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "simple_function.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyAdd_55e046a8");
  // Every index below is guarded by a fatal size/rank check so a wrong result
  // fails the test instead of reading out of bounds.
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_FALSE(out_prop.shape().unknown_rank());
  ASSERT_EQ(2, out_prop.shape().dim_size());
  EXPECT_EQ(1, out_prop.shape().dim(0).size());
  EXPECT_EQ(2, out_prop.shape().dim(1).size());

  const auto& in_props = properties.GetInputProperties("MyAdd_55e046a8");
  ASSERT_EQ(2, in_props.size());
  EXPECT_EQ("float: [1,2]", PropToString(in_props[0]));
  EXPECT_EQ("float: [1,2]", PropToString(in_props[1]));
}
1734 
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_function_graph.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Sizes are checked fatally because every element is indexed afterwards.
  const auto& out_props = properties.GetOutputProperties("y0");
  ASSERT_EQ(2, out_props.size());
  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_props[0]));
  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_props[1]));

  const auto& in_props = properties.GetInputProperties("y0");
  ASSERT_EQ(4, in_props.size());
  EXPECT_EQ("float: [64]", PropToString(in_props[0]));
  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_props[1]));
  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_props[2]));
  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_props[3]));
}
1767 
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
  // Test graph produced in python using:
  /*
    @function.Defun(noinline=True)
    def MyFunc():
      @function.Defun(*[tf.float32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[tf.float32] * 2)
      def Body(n, x):
        return n - 1, x + n

      i = tf.constant(10)
      return functional_ops.While([i, 0.], Cond, Body)

    with tf.Graph().as_default():
      z = MyFunc()
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_functional_while.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
  // Both outputs are indexed below; stop here if either is missing.
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_INT32, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
  EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
1805 
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
  // Inference must still succeed on function_error.pbtxt; the function's output
  // shape is expected to be of unknown rank while its inputs stay known.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_error.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(1, out_props.size());

  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_TRUE(out_prop.shape().unknown_rank());

  const auto& in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
  // Both inputs are indexed below; stop here if either is missing.
  ASSERT_EQ(2, in_props.size());
  EXPECT_EQ("float: [1,2]", PropToString(in_props[0]));
  EXPECT_EQ("float: [1,2]", PropToString(in_props[1]));
}
1830 
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  // Guard the indexing below so a missing entry fails cleanly.
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());
  EXPECT_EQ("float: [1,2]", PropToString(in_props[0]));
  EXPECT_EQ("float: [1,2]", PropToString(in_props[1]));
}
1864 
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_2.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  // Guard the indexing below so a missing entry fails cleanly.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2]", PropToString(out_props[0]));

  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());
  EXPECT_EQ("float: [1,2]", PropToString(in_props[0]));
  EXPECT_EQ("float: [1,2]", PropToString(in_props[1]));
}
1897 
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      c = tf.add(x, a)
      d = tf.add(y, b)
      return c

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_shapes.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
  // Guard the indexing below so a missing entry fails cleanly.
  ASSERT_FALSE(out_props.empty());
  EXPECT_EQ("float: [1,2]", PropToString(out_props[0]));

  const auto& in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(2, in_props.size());
  EXPECT_EQ("float: [1,2]", PropToString(in_props[0]));
  EXPECT_EQ("float: [1,3]", PropToString(in_props[1]));
}
1934 
TEST_F(GraphPropertiesTest,SymbolicShapes)1935 TEST_F(GraphPropertiesTest, SymbolicShapes) {
1936   // Build a simple graph with placeholders of unknown dimensions. These
1937   // dimensions will be encoded symbolically.
1938   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1939 
1940   Output a =
1941       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
1942                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
1943   Output b =
1944       ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
1945                        ops::Placeholder::Shape(PartialTensorShape({-1})));
1946   Output c = ops::Identity(s.WithOpName("c"), a);
1947   Output d = ops::Identity(s.WithOpName("d"), b);
1948   Output e = ops::Add(s.WithOpName("e"), c, d);
1949   Output f = ops::Add(s.WithOpName("f"), a, c);
1950 
1951   Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
1952   Output g = ops::Shape(s.WithOpName("g"), c);
1953   Output h = ops::Fill(s.WithOpName("h"), g, zero);
1954   Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
1955   Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
1956 
1957   GrapplerItem item;
1958   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1959 
1960   GraphProperties properties(item);
1961   TF_ASSERT_OK(properties.InferStatically(false));
1962   const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
1963   const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
1964   EXPECT_EQ(2, shape_a.dim_size());
1965   EXPECT_EQ(shape_a.dim_size(), shape_c.dim_size());
1966   EXPECT_GE(-2, shape_a.dim(0).size());
1967   EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
1968   EXPECT_GE(-2, shape_a.dim(1).size());
1969   EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());
1970 
1971   PartialTensorShape shape(shape_a);
1972   EXPECT_FALSE(shape.IsFullyDefined());
1973   EXPECT_FALSE(shape.unknown_rank());
1974 
1975   const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
1976   const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
1977   EXPECT_EQ(1, shape_b.dim_size());
1978   EXPECT_EQ(shape_b.dim_size(), shape_d.dim_size());
1979   EXPECT_GE(-2, shape_b.dim(0).size());
1980   EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
1981   EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());
1982 
1983   const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
1984   ASSERT_EQ(2, shape_e.dim_size());
1985   EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
1986   EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
1987   EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());
1988 
1989   const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
1990   ASSERT_EQ(2, shape_f.dim_size());
1991   EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
1992   EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
1993 
1994   const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
1995   ASSERT_EQ(2, shape_f.dim_size());
1996   EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
1997   EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
1998 
1999   const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
2000   ASSERT_EQ(1, shape_j.dim_size());
2001   EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
2002 }
2003 
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  // "c" carries a colocation constraint pointing at "a"; "b" is unrelated.
  Output a = ops::Const(root.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(root.WithOpName("b"), 2.0f, {1});
  Output c = ops::Const(root.WithOpName("c").ColocateWith(a), 3.0f, {1});
  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  // Emulate a late optimization pass that dropped "a" while "c" still names it
  // as a colocation target. At that stage colocation has already been checked
  // and placement is done, so the dangling reference is acceptable. Only the
  // surviving nodes are copied into the rebuilt GraphDef.
  GraphDef pruned;
  for (const auto& node : item.graph.node()) {
    if (node.name() == "a") continue;
    pruned.add_node()->CopyFrom(node);
  }
  item.graph.Swap(&pruned);

  // InferStatically does not validate colocation constraints, so it must
  // still succeed on the pruned graph.
  GraphProperties properties(item);
  TF_EXPECT_OK(properties.InferStatically(false));
}
2027 
TEST_F(GraphPropertiesTest, ShapeTracking) {
  // Fill ops fed by ShapeN outputs should report exactly the shapes of the
  // corresponding ShapeN inputs, including their unknown dimensions.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b =
      ops::Placeholder(root.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output zero = ops::Const(root.WithOpName("zero"), 0.0f, {});
  auto shapes = ops::ShapeN(root.WithOpName("shapes"), {a, b});
  Output o1 = ops::Fill(root.WithOpName("o1"), shapes[0], zero);
  Output o2 = ops::Fill(root.WithOpName("o2"), shapes[1], zero);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Returns a copy of the shape of output 0 of the named node.
  auto output_shape = [&properties](const char* node_name) {
    return properties.GetOutputProperties(node_name).at(0).shape();
  };
  EXPECT_EQ(output_shape("a").DebugString(), output_shape("o1").DebugString());
  EXPECT_EQ(output_shape("b").DebugString(), output_shape("o2").DebugString());
}
2053 
TEST_F(GraphPropertiesTest,FedNodes)2054 TEST_F(GraphPropertiesTest, FedNodes) {
2055   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
2056                                           cluster_->GetDeviceNames());
2057   GrapplerItem item;
2058   CHECK(fake_input.NextItem(&item));
2059 
2060   {
2061     // Conservative shape analysis: the shape of fed ports should be unknown
2062     GraphProperties properties(item);
2063     Status s = properties.InferStatically(false);
2064     TF_ASSERT_OK(s);
2065     for (const auto& node : item.graph.node()) {
2066       if (node.op() == "Const") {
2067         continue;
2068       }
2069       const auto in_props = properties.GetInputProperties(node.name());
2070       EXPECT_EQ(1, in_props.size());
2071       const OpInfo::TensorProperties& in_prop = in_props[0];
2072       const auto out_props = properties.GetOutputProperties(node.name());
2073       EXPECT_EQ(1, out_props.size());
2074       const OpInfo::TensorProperties& out_prop = out_props[0];
2075 
2076       if (node.name() == "x") {
2077         // x is fed: its input should have a known shape, while its output
2078         // doesn't
2079         EXPECT_FALSE(in_prop.shape().unknown_rank());
2080         EXPECT_EQ(1, in_prop.shape().dim_size());
2081         EXPECT_EQ(2, in_prop.shape().dim(0).size());
2082         EXPECT_TRUE(out_prop.shape().unknown_rank());
2083       } else if (node.op() == "Square" || node.op() == "AddN") {
2084         // These nodes are in the fanout of x: their shapes should be unknown.
2085         EXPECT_TRUE(in_prop.shape().unknown_rank());
2086         EXPECT_TRUE(out_prop.shape().unknown_rank());
2087       }
2088     }
2089   }
2090   {
2091     // Optimistic shape analysis: the shape of fed ports should be derived from
2092     // the shape of the fanin.
2093     GraphProperties properties(item);
2094     Status s = properties.InferStatically(true);
2095     TF_ASSERT_OK(s);
2096     for (const auto& node : item.graph.node()) {
2097       if (node.op() == "Square" || node.op() == "AddN") {
2098         const auto in_props = properties.GetInputProperties(node.name());
2099         EXPECT_EQ(1, in_props.size());
2100         const OpInfo::TensorProperties& in_prop = in_props[0];
2101         EXPECT_EQ(DT_FLOAT, in_prop.dtype());
2102         EXPECT_FALSE(in_prop.shape().unknown_rank());
2103         EXPECT_EQ(2, in_prop.shape().dim_size());
2104         const auto out_props = properties.GetOutputProperties(node.name());
2105         EXPECT_EQ(1, out_props.size());
2106         const OpInfo::TensorProperties& out_prop = out_props[0];
2107         EXPECT_EQ(in_prop.dtype(), out_prop.dtype());
2108         EXPECT_EQ(in_prop.shape().DebugString(),
2109                   out_prop.shape().DebugString());
2110       }
2111     }
2112   }
2113 }
2114 
TEST_F(GraphPropertiesTest, Performance) {
  // Load a large graph with many nested loops to make sure we can infer shapes
  // quickly.
  GrapplerItem item;
  // The test data is the GraphDef file "large_graph.pbtxt" under kTestDataPath.
  // NOTE(review): a previous revision had a stray ".html" suffix here, which
  // looks like a code-browser artifact; confirm against the testdata dir.
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_graph.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  TF_ASSERT_OK(AddDefaultAttrsToGraphDef(
      &item.graph,
      FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library()), 0,
      true));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
}
2130 
TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
  // Slicing Shape(a) into its individual dims and feeding each slice to Fill
  // should yield rank-1 outputs whose single dim tracks the matching dim of a.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  auto shp = ops::Shape(s.WithOpName("shape"), {a});

  Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
  Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
  Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});

  // b = shp[0:1], c = shp[1:2] (stride 1).
  Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
  Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
  const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
  // ASSERT the ranks first: the dim() lookups below index into these shapes.
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(1, shape_o1.dim_size());
  ASSERT_EQ(1, shape_o2.dim_size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
}
2163 
TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(scope.WithOpName("input_placeholder"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
  auto input_shape = ops::Shape(scope.WithOpName("input_shape"), placeholder);

  Output begin = ops::Const(scope.WithOpName("begin"), {0}, {1});
  Output end = ops::Const(scope.WithOpName("end"), {3}, {1});
  Output stride = ops::Const(scope.WithOpName("stride"), {1}, {1});

  // ShrinkAxisMask(1) collapses axis 0, so the slice selects element `begin`
  // (index 0) of input_shape, i.e. 5, as a scalar.
  Output slice =
      ops::StridedSlice(scope.WithOpName("slice"), input_shape, begin, end,
                        stride, ops::StridedSlice::ShrinkAxisMask(1));

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  // Without aggressive shape inference, it cannot infer output value of
  // StridedSlice with ShrinkAxisMask.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value());
  }

  // InferStatically with aggressive shape inference can infer output value of
  // StridedSlice with ShrinkAxisMask.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& slice_prop = properties.GetOutputProperties("slice").at(0);
    // Fatal: there is nothing meaningful to compare if no value was inferred.
    ASSERT_TRUE(slice_prop.has_value());
    ExpectTensorValues({5}, slice_prop.value());
  }
}
2207 
TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
  // With aggressive shape inference and tensor values enabled, constant values
  // should propagate through OnesLike and chained Add ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output b = ops::Const(s.WithOpName("b"), {8, 8}, {2});
  Output c = ops::Const(s.WithOpName("c"), {2, 2}, {2});

  Output a1 = ops::OnesLike(s.WithOpName("a1"), a);
  Output a_plus_one = ops::Add(s.WithOpName("a_plus_one"), a, a1);
  Output a_plus_a = ops::Add(s.WithOpName("a_plus_a"), a, a);
  Output b_plus_2a = ops::Add(s.WithOpName("b_plus_2a"), b, a_plus_a);
  Output c_plus_b_plus_2a =
      ops::Add(s.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));

  // Check output shapes and values. at(0) fails the test cleanly instead of
  // reading out of bounds if a node has no recorded outputs, and the
  // has_value() checks are fatal because value() is read right after.
  const auto& a_plus_one_prop =
      properties.GetOutputProperties("a_plus_one").at(0);
  EXPECT_EQ("int32: [2]", PropToString(a_plus_one_prop));
  ASSERT_TRUE(a_plus_one_prop.has_value());
  ExpectTensorValues({6, 8}, a_plus_one_prop.value());

  const auto& a_plus_a_prop = properties.GetOutputProperties("a_plus_a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(a_plus_a_prop));
  ASSERT_TRUE(a_plus_a_prop.has_value());
  ExpectTensorValues({10, 14}, a_plus_a_prop.value());

  const auto& b_plus_2a_prop =
      properties.GetOutputProperties("b_plus_2a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(b_plus_2a_prop));
  ASSERT_TRUE(b_plus_2a_prop.has_value());
  ExpectTensorValues({18, 22}, b_plus_2a_prop.value());

  const auto& c_plus_b_plus_2a_prop =
      properties.GetOutputProperties("c_plus_b_plus_2a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(c_plus_b_plus_2a_prop));
  ASSERT_TRUE(c_plus_b_plus_2a_prop.has_value());
  ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
}
2251 
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({5, 7})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    // ASSERT (not EXPECT) so props[0] below stays in bounds.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
2296 
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes: [10, 100] is compatible with the inferred [-1, 100].
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 100})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // ASSERT (not EXPECT) so props[0] below stays in bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2324 
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes: [10, 10] conflicts with the inferred [-1, 100].
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 10})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // ASSERT (not EXPECT) so props[0] below stays in bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
2352 
TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes on an op that has no shape inference function, so the
  // annotation is the only source of output shape information.
  TF_ASSERT_OK(
      NodeDefBuilder("TestOpWithNoInferenceFn", "TestOpWithNoInferenceFn")
          .Attr("_same_output_for_iterations", true)
          .Attr("_output_shape_vector", {TensorShape({10, 100})})
          .Input("Input", 0, DT_FLOAT)
          .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
  // ASSERT (not EXPECT) so props[0] below stays in bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2379 
TEST_F(GraphPropertiesTest, PartitionedCallOp) {
  // Shape inference should see through a PartitionedCall to an identity
  // function and report the argument's shape on the call's output.
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;
  FunctionDef called_func = FunctionDefHelper::Create(
      "identity_function",
      /*in_def=*/{"arg0: int32"},
      /*out_def=*/{"ret0: int32"},
      /*attr_def=*/{},
      {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "Identity:output:0"}});
  *library.add_function() = called_func;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));

  Output in = ops::Const(root, {3, 1, 2, 0});
  NameAttrList b_name_attr;
  b_name_attr.set_name("identity_function");
  ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& call_props = properties.GetOutputProperties("identity_call");
  // The call declares one output (Tout = {DT_INT32}); ASSERT before indexing.
  ASSERT_EQ(1, call_props.size());
  EXPECT_EQ("int32: [4]", PropToString(call_props[0]));
}
2411 
TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
  // A PartitionedCall to a two-input Add function should propagate the
  // partially known input shape [2, 2, -1] to its output.
  auto f = FunctionDefHelper::Create(
      // Name
      "FunctionWhichAdds",
      // Inputs
      {"arg0: int32", "arg1: int32"},
      // Outputs
      {"ret0: int32"},
      /*attr_def=*/{},
      // Nodes
      {{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "a:z:0"}});

  FunctionDefLibrary function_lib;
  function_lib.add_function()->Swap(&f);
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib));

  PartialTensorShape input_shape({2, 2, -1});
  Output in1 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  Output in2 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  NameAttrList b_name_attr;
  b_name_attr.set_name("FunctionWhichAdds");
  ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& call_props = properties.GetOutputProperties("add_call");
  // The call declares one output (Tout = {DT_INT32}); ASSERT before indexing.
  ASSERT_EQ(1, call_props.size());
  EXPECT_EQ("int32: [2,2,-1]", PropToString(call_props[0]));
}
2452 
TEST_F(GraphPropertiesTest,ShapeAnnotatedFunctionOp)2453 TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
2454   // A function, which we cannot infer output shape statically.
2455   auto f = FunctionDefHelper::Create(
2456       // Name
2457       "FuncShapeCannotBeInferred",
2458       // Inputs
2459       {},
2460       // Outputs
2461       {"output: float"},
2462       // Attrs
2463       {},
2464       // Nodes
2465       {
2466           // Placeholder without shape attr; unknown rank.
2467           {{"p"}, "Placeholder", {}, {{"dtype", DataType::DT_FLOAT}}},
2468       },
2469       // Returns
2470       {{"output", "p:output:0"}});
2471   FunctionDefLibrary function_lib;
2472   function_lib.add_function()->Swap(&f);
2473   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2474   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib));
2475   tensorflow::Node* func_op;
2476   TensorShapeProto output_shape;
2477   output_shape.set_unknown_rank(false);
2478   output_shape.add_dim()->set_size(1);
2479   output_shape.add_dim()->set_size(2);
2480   output_shape.add_dim()->set_size(3);
2481   output_shape.add_dim()->set_size(4);
2482   // The function node, f, includes shape annotation.
2483   TF_ASSERT_OK(tensorflow::NodeBuilder("f", "FuncShapeCannotBeInferred",
2484                                        s.graph()->op_registry())
2485                    .Attr("_execution_count", 1)
2486                    .Attr("_same_output_for_iterations", true)
2487                    .Attr("_output_dtype_vector", {DataType::DT_FLOAT})
2488                    .Attr("_output_shape_vector", {output_shape})
2489                    .Finalize(s.graph(), &func_op));
2490   GrapplerItem item;
2491   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2492 
2493   // InferStatically with aggressive_shape_inference would fail to infer
2494   // the output shape of the node f.
2495   {
2496     GraphProperties properties(item);
2497     TF_ASSERT_OK(properties.InferStatically(
2498         /*assume_valid_feeds=*/false,
2499         /*aggressive_shape_inference=*/false,
2500         /*include_tensor_values=*/false));
2501     const auto out_props = properties.GetOutputProperties("f");
2502     const OpInfo::TensorProperties out_prop0 = out_props[0];
2503     EXPECT_EQ("float: ?", PropToString(out_prop0));
2504   }
2505   // With aggressive_shape_inference, it skips recursively callying
2506   // InferStatically for the function node and outputs annotated shape info.
2507   {
2508     GraphProperties properties(item);
2509     TF_ASSERT_OK(properties.InferStatically(
2510         /*assume_valid_feeds=*/false,
2511         /*aggressive_shape_inference=*/true,
2512         /*include_tensor_values=*/true));
2513     const auto out_props = properties.GetOutputProperties("f");
2514     const OpInfo::TensorProperties out_prop0 = out_props[0];
2515     EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
2516   }
2517 }
2518 
TEST_F(GraphPropertiesTest,SymbolicShapeInferenceWithReshapeOpsSharingShapeVector)2519 TEST_F(GraphPropertiesTest,
2520        SymbolicShapeInferenceWithReshapeOpsSharingShapeVector) {
2521   GrapplerItem item;
2522   // This graph creates a shape vector [-1, 10] from Concat(Const, Const)
2523   // used for two reshape ops. One reshape op is segment_ids input to
2524   // UnsortedSegmentSum op, which applies MergePrefix from its shape function.
2525   // segment_ids has a shape [-1, 10] (from reshape), but MergePrefix with
2526   // data input ([10, 10, 10, 10]) makes -1, or unknown dim, 10, with
2527   // SymbolicShapeRefiner.
2528   // This dim value (10), however, should not affect the other reshape op, even
2529   // though it shares the shape input; -1 in the shape input of Reshape op is
2530   // a special case of computed output dim, not unknown dim.
2531   // data and num_segments are inputs to UnsortedSegmenetSum.
2532 
2533   TF_ASSERT_OK(NodeDefBuilder("data", "Placeholder")
2534                    .Attr("dtype", DT_FLOAT)
2535                    .Attr("shape", TensorShape({10, 10, 10, 10}))
2536                    .Finalize(item.graph.add_node()));
2537   Tensor num_segments(DT_INT32, TensorShape({}));
2538   // Build semgent_ids input to UnsortedSegmentSum from Const ops, ConcatV2,
2539   // and Reshape ops. tensors_as_shape from Const ops are propagated to ConcatV2
2540   // output to form shape vector [-1, 10] to Reshape.
2541   test::FillIota<int>(&num_segments, 3);
2542   TF_ASSERT_OK(NodeDefBuilder("num_segments", "Const")
2543                    .Attr("dtype", DT_INT32)
2544                    .Attr("value", num_segments)
2545                    .Finalize(item.graph.add_node()));
2546   Tensor minus_one(DT_INT32, TensorShape({1}));
2547   test::FillIota<int>(&minus_one, -1);
2548   TF_ASSERT_OK(NodeDefBuilder("minus_one", "Const")
2549                    .Attr("dtype", DT_INT32)
2550                    .Attr("value", minus_one)
2551                    .Finalize(item.graph.add_node()));
2552   Tensor plus_ten(DT_INT32, TensorShape({1}));
2553   test::FillIota<int>(&plus_ten, 10);
2554   TF_ASSERT_OK(NodeDefBuilder("plus_ten", "Const")
2555                    .Attr("dtype", DT_INT32)
2556                    .Attr("value", plus_ten)
2557                    .Finalize(item.graph.add_node()));
2558   Tensor axis(DT_INT32, TensorShape({}));
2559   test::FillIota<int>(&axis, -1);
2560   TF_ASSERT_OK(NodeDefBuilder("axis", "Const")
2561                    .Attr("dtype", DT_INT32)
2562                    .Attr("value", axis)
2563                    .Finalize(item.graph.add_node()));
2564   std::vector<NodeDefBuilder::NodeOut> inputs(2);
2565   inputs[0] = NodeDefBuilder::NodeOut{"minus_one", 0, DT_INT32};
2566   inputs[1] = NodeDefBuilder::NodeOut{"plus_ten", 0, DT_INT32};
2567   TF_ASSERT_OK(NodeDefBuilder("concat", "ConcatV2")
2568                    .Input(inputs)
2569                    .Input("axis", 0, DT_INT32)
2570                    .Attr("N", 2)
2571                    .Attr("T", DT_INT32)
2572                    .Attr("Tidx", DT_INT32)
2573                    .Finalize(item.graph.add_node()));
2574   TF_ASSERT_OK(NodeDefBuilder("segment_ids_", "Placeholder")
2575                    .Attr("dtype", DT_FLOAT)
2576                    .Finalize(item.graph.add_node()));
2577   TF_ASSERT_OK(NodeDefBuilder("segment_ids_shape_before_reshape", "Shape")
2578                    .Input("segment_ids_", 0, DT_FLOAT)
2579                    .Attr("T", DT_FLOAT)
2580                    .Attr("out_type", DT_INT32)
2581                    .Finalize(item.graph.add_node()));
2582   TF_ASSERT_OK(NodeDefBuilder("segment_ids", "Reshape")
2583                    .Input("segment_ids_", 0, DT_FLOAT)
2584                    .Input("concat", 0, DT_INT32)
2585                    .Attr("T", DT_FLOAT)
2586                    .Attr("Tshape", DT_INT32)
2587                    .Finalize(item.graph.add_node()));
2588   // Shape function of UnsortedSegmentSum applies MergePrefix to data and
2589   // segment_ids (the latter being prefix). data shape is [10,10,10,10] and
2590   // segment_ids shape is [-1, 10], but MergePrefix and symbolic shape inference
2591   // assign 10 from data shape to the unknown dim in segment_ids.
2592   TF_ASSERT_OK(NodeDefBuilder("y", "UnsortedSegmentSum")
2593                    .Input("data", 0, DT_FLOAT)
2594                    .Input("segment_ids", 0, DT_INT32)
2595                    .Input("num_segments", 0, DT_INT32)
2596                    .Attr("T", DT_FLOAT)
2597                    .Attr("Tindices", DT_INT32)
2598                    .Attr("Tnumsegments", DT_INT32)
2599                    .Finalize(item.graph.add_node()));
2600   // Note that y2=Reshape(x1) using the same shape vector as segment_ids, but
2601   // y2 shape shouldn't be affected by symbolic shape inference w/ segment_ids.
2602   TF_ASSERT_OK(NodeDefBuilder("x1", "Placeholder")
2603                    .Attr("dtype", DT_FLOAT)
2604                    .Finalize(item.graph.add_node()));
2605   TF_ASSERT_OK(NodeDefBuilder("y1", "Reshape")
2606                    .Input("x1", 0, DT_FLOAT)
2607                    .Input("concat", 0, DT_INT32)
2608                    .Attr("T", DT_FLOAT)
2609                    .Attr("Tshape", DT_INT32)
2610                    .Finalize(item.graph.add_node()));
2611 
2612   GraphProperties properties(item);
2613   TF_ASSERT_OK(properties.InferStatically(true));
2614   const auto& y1_output_properties = properties.GetOutputProperties("y1");
2615   // y1=reshape(x1), but x1's shape in unknown, so y1 should be [-1, 10].
2616   // The first dimension should not be 10.
2617   EXPECT_EQ(y1_output_properties.size(), 1);
2618   EXPECT_EQ(y1_output_properties[0].shape().dim_size(), 2);
2619   EXPECT_LT(y1_output_properties[0].shape().dim(0).size(), 0);
2620   EXPECT_EQ(y1_output_properties[0].shape().dim(1).size(), 10);
2621 }
2622 
TEST(HelperFunctions, IsShapeFullyDefinedIntegerVectorOrScalar) {
  // Make a dummy InferenceContext.
  NodeDef node_def;
  OpRegistrationData op_reg_data;
  OpDefBuilder b("dummy");
  // TF_ASSERT_OK reports the Status message and fails only this test, where
  // CHECK would abort the whole test binary.
  TF_ASSERT_OK(b.Finalize(&op_reg_data));
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types;
  InferenceContext ic(/*graph_def_version=*/0, node_def, op_reg_data.op_def,
                      /*input_shapes=*/{},
                      /*input_tensors=*/{},
                      /*input_tensors_as_shapes=*/{},
                      std::move(input_handle_shapes_and_types));

  // ShapeHandles for testing.
  ShapeHandle fully_defined_vector = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle vector_with_unknown = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.UnknownDim(), ic.MakeDim(7)});
  // INT64_MAX is used as unknown from Const. See kUnknownFromConst const in
  // graph_properties.cc
  ShapeHandle vector_with_unknown_from_const = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(INT64_MAX), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle rank_1_vector = ic.MakeShape({ic.MakeDim(4)});

  // Rank-1 shape and fully defined tensor_as_shape with INT32 or INT64.
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT32));
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT64));

  // Non-integer data type.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_FLOAT));

  // tensor_as_shape including Unknown or UnknownFromConst, or an unknown
  // shape / tensor_as_shape altogether.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown_from_const, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, ic.UnknownShape(), DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, ic.UnknownShape(), fully_defined_vector, DT_INT32));

  // shape rank > 1 (fully_defined_vector used as the shape has rank 4).
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, fully_defined_vector, vector_with_unknown_from_const, DT_INT32));
}
2672 }  // namespace
2673 }  // namespace grappler
2674 }  // namespace tensorflow
2675