1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/functional_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/graph_def_util.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
31 #include "tensorflow/core/grappler/inputs/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/test.h"
37 #ifdef INTEL_MKL
38 #include "tensorflow/core/graph/mkl_graph_util.h"
39 #endif
40
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44
45 using shape_inference::InferenceContext;
46 using shape_inference::ShapeAndType;
47 using shape_inference::ShapeHandle;
48
// Directory, relative to the TensorFlow source root, that holds the checked-in
// GraphDef text protos (*.pbtxt) read by the file-based tests in this file
// (joined with testing::TensorFlowSrcRoot() at each use site).
const char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata";

// Test-only op registered without a shape function (no SetShapeFn call), so
// shape inference has nothing to run for it.
// NOTE(review): not referenced in this chunk; presumably used by tests later
// in the file — confirm.
REGISTER_OP("TestOpWithNoInferenceFn")
    .Input("x: float")
    .Output("y: float")
    .Doc(R"doc(
Test op with no Inference Function registered.
x: input
y: output
)doc");
59
60 class GraphPropertiesTest : public ::testing::Test {
61 public:
SetUp()62 void SetUp() override {
63 // Provision a single machine with 3 cpu cores
64 cluster_.reset(new SingleMachine(5 * 60, 3, 0));
65 TF_ASSERT_OK(cluster_->Provision());
66
67 // This function is simply
68 // out = Fill(shape, value), but
69 // Fill requires values in the shape input, not just shape of it, to infer
70 // output shape.
71 auto f = FunctionDefHelper::Create(
72 // Name
73 "MyFillFunc",
74 // Inputs
75 {"shape: int32", "value: float"},
76 // Outputs
77 {"out: float"},
78 // Attrs
79 {},
80 // Nodes
81 {
82 {{"a"},
83 "Fill",
84 {"shape", "value"},
85 {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
86 },
87 // Returns
88 {{"out", "a:output:0"}});
89 function_lib_.add_function()->Swap(&f);
90 }
91
TearDown()92 void TearDown() override {
93 TF_ASSERT_OK(cluster_->Shutdown());
94 cluster_.reset();
95 }
96
97 protected:
98 // Returns a string form of <p>, suitable for comparing type and shape.
99 // Example output for 4-d float tensor: "float: [10,2,30,4]"
PropToString(const OpInfo::TensorProperties & p)100 string PropToString(const OpInfo::TensorProperties& p) {
101 string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
102 if (p.shape().unknown_rank()) {
103 strings::StrAppend(&s, "?");
104 } else {
105 strings::StrAppend(&s, "[");
106 for (int i = 0; i < p.shape().dim_size(); ++i) {
107 strings::StrAppend(&s, i == 0 ? "" : ",",
108 std::max<int64>(p.shape().dim(i).size(), -1));
109 }
110 strings::StrAppend(&s, "]");
111 }
112 return s;
113 }
114
115 // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
116 // ones.
ExpectTensorValues(const std::vector<int64> & expected,const TensorProto & tensor_proto_to_compare)117 void ExpectTensorValues(const std::vector<int64>& expected,
118 const TensorProto& tensor_proto_to_compare) {
119 Tensor tensor;
120 ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
121 EXPECT_EQ(expected.size(), tensor.NumElements());
122 // We're interested in only integer tensors as only shapes are exported as
123 // graph properties values.
124 ASSERT_TRUE(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
125 if (tensor.dtype() == DT_INT32) {
126 for (int i = 0; i < tensor.NumElements(); i++) {
127 EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
128 }
129 } else {
130 for (int i = 0; i < tensor.NumElements(); i++) {
131 EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
132 }
133 }
134 }
135
136 // Compare values of float (DT_FLOAT) tensor against expected
137 // ones.
ExpectFloatTensorValues(const std::vector<float> & expected,const TensorProto & tensor_proto_to_compare)138 void ExpectFloatTensorValues(const std::vector<float>& expected,
139 const TensorProto& tensor_proto_to_compare) {
140 Tensor tensor;
141 ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
142 EXPECT_EQ(expected.size(), tensor.NumElements());
143 ASSERT_EQ(tensor.dtype(), DT_FLOAT);
144 for (int i = 0; i < tensor.NumElements(); i++) {
145 EXPECT_EQ(expected[i], tensor.flat<float>()(i));
146 }
147 }
148
149 std::unique_ptr<SingleMachine> cluster_;
150 FunctionDefLibrary function_lib_;
151 };
152
// Statically infers properties on the trivial test graph and checks the
// RandomStandardNormal and AddN nodes report float [10,1] tensors.
TEST_F(GraphPropertiesTest, StaticProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_ASSERT_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The node has one input (the shape of the tensor to generate).
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output. Fatal check: props[0] is read below.
      const auto props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, props.size());
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype());
      EXPECT_FALSE(prop.shape().unknown_rank());
      EXPECT_EQ(2, prop.shape().dim_size());
      EXPECT_EQ(10, prop.shape().dim(0).size());
      EXPECT_EQ(1, prop.shape().dim(1).size());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      EXPECT_EQ(DT_FLOAT, in_prop.dtype());
      EXPECT_FALSE(in_prop.shape().unknown_rank());
      EXPECT_EQ(2, in_prop.shape().dim_size());
      EXPECT_EQ(10, in_prop.shape().dim(0).size());
      EXPECT_EQ(1, in_prop.shape().dim(1).size());
      // AddN of a single input must propagate that input's type and shape.
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      EXPECT_EQ(in_prop.dtype(), out_props[0].dtype());
      EXPECT_EQ(in_prop.shape().DebugString(),
                out_props[0].shape().DebugString());
    }
  }
}
193
// Checks that ClearOutputProperties / ClearInputProperties drop previously
// inferred properties for a node.
TEST_F(GraphPropertiesTest, ClearProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_ASSERT_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // Properties must be populated before clearing; otherwise the emptiness
      // check below would pass trivially.
      EXPECT_EQ(1, properties.GetOutputProperties(node.name()).size());
      properties.ClearOutputProperties(node.name());
      const auto cleared_props = properties.GetOutputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    } else if (node.op() == "AddN") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      properties.ClearInputProperties(node.name());
      const auto cleared_props = properties.GetInputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    }
  }
}
220
// Infers properties by running the trivial test graph on the cluster and
// checks what the resulting cost graph reports for RandomStandardNormal and
// AddN nodes.
TEST_F(GraphPropertiesTest, DynamicProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(cluster_->Initialize(item));
  Status s = properties.InferDynamically(cluster_.get());
  TF_ASSERT_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The random node is missing from the cost graph (why ?)
      EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
    } else if (node.op() == "AddN") {
      // Since the random node is missing, we can't infer the input properties
      // of the first AddN node. The other AddN nodes have the expected
      // properties.
      if (node.name() == "AddN") {
        const auto props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, props.size());
        const OpInfo::TensorProperties& prop = props[0];
        EXPECT_EQ(DT_INVALID, prop.dtype());
        EXPECT_TRUE(prop.shape().unknown_rank());
      } else {
        const auto props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, props.size());
        const OpInfo::TensorProperties& prop = props[0];
        EXPECT_EQ(DT_FLOAT, prop.dtype());
        EXPECT_FALSE(prop.shape().unknown_rank());
        EXPECT_EQ(2, prop.shape().dim_size());
        EXPECT_EQ(10, prop.shape().dim(0).size());
        EXPECT_EQ(1, prop.shape().dim(1).size());
        const auto out_props = properties.GetOutputProperties(node.name());
        // Fatal checks: out_props[0] is read below.
#ifdef INTEL_MKL
        if (!NativeFormatEnabled()) {
          // Intel MKL AddN OP would have two output.
          // One is the real output, another one for MKL metadata
          ASSERT_EQ(2, out_props.size());
        } else {
          ASSERT_EQ(1, out_props.size());
        }
#else
        ASSERT_EQ(1, out_props.size());
#endif  // INTEL_MKL
        string prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
        string out_prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
                                                          &out_prop_str);
        EXPECT_EQ(prop_str, out_prop_str);
      }
    }
  }
}
277
// Checks that a ref-typed Variable of shape [3,7] is reported as
// float_ref [3,7] by both static and dynamic inference.
TEST_F(GraphPropertiesTest, Variables) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "Variable")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  item.fetch.push_back("Var");

  Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
  test::FillIota<float>(&initial_val, 0);
  TF_ASSERT_OK(NodeDefBuilder("InitialVal", "Const")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", initial_val)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("InitVar", "Assign")
                   .Input("Var", 0, DT_FLOAT_REF)
                   .Input("InitialVal", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  item.init_ops.push_back("InitVar");

  {
    GraphProperties static_properties(item);
    TF_ASSERT_OK(static_properties.InferStatically(false));

    const auto props = static_properties.GetOutputProperties("Var");
    // Fatal check: props[0] is read below.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    EXPECT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
  {
    TF_ASSERT_OK(cluster_->Initialize(item));
    GraphProperties dynamic_properties(item);
    TF_ASSERT_OK(dynamic_properties.InferDynamically(cluster_.get()));

    const auto props = dynamic_properties.GetOutputProperties("Var");
    // Fatal check: props[0] is read below.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    EXPECT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
}
326
// Checks that a resource handle's shape survives an Enter node: reading the
// VarHandleOp through Enter must still yield float [3,7].
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("Enter", "Enter")
                   .Attr("T", DT_RESOURCE)
                   .Attr("frame_name", "while_context")
                   .Attr("is_constant", true)
                   .Attr("parallel_iterations", 10)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Enter", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
  // Fatal check: props[0] is read below.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
356
// Checks that ReadVariableOp recovers the [3,7] float shape recorded on a
// VarHandleOp's resource handle.
TEST_F(GraphPropertiesTest, VarHandles) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));

  TF_ASSERT_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto props = properties.GetOutputProperties("VarRead");
  // Fatal check: props[0] is read below.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
381
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
  // Test graph is first generated in python using:
  /*
    i0 = tf.constant(0)
    v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
    def cond(i, x):
      return i < 3
    def body(i, x):
      return i + 1, x + x
    v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
  */
  // and then modified by hand such that the ReadVariableOp is inside the loop
  // body instead of outside the while loop (which is the case when constructed
  // using the python API), such that we have the following pattern: VarHandleOp
  // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
  // DT_RESOURCE is passed all the way until ReadVariableOp.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop_var_handle_op.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> resource_nodes{
      "loop_var",       "while/Enter",         "while/Merge", "while/Switch",
      "while/Identity", "while/NextIteration", "while/Exit"};
  for (const string& node : resource_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Merge has 2 outputs; only the first is checked. Fatal: props[0] below.
    ASSERT_GE(props.size(), 1) << node;
    EXPECT_EQ("resource: []", PropToString(props[0])) << node;
  }

  // After ReadVariableOp, the shape should be recovered.
  const auto props = properties.GetOutputProperties("while/ReadVariableOp");
  ASSERT_EQ(1, props.size());
  EXPECT_EQ("int32: []", PropToString(props[0]));
}
419
// A FIFO queue with no "shapes" attr and no enqueue ops: the dequeued value's
// dtype comes from the component types, but its rank stays unknown.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue =
      ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT});
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeued = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: ?", PropToString(dequeued.front()));
}
436
// A FIFO queue whose "shapes" attr fully specifies [3,7,1]: with no enqueue
// ops, the dequeued shape must come straight from that attr.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // The Attrs object is passed within the same expression on purpose: it may
  // only view the braced shape list.
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                              ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeued = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: [3,7,1]", PropToString(dequeued.front()));
}
454
// A FIFO queue whose "shapes" attr is only partially known ([3,7,-1]): the
// unknown trailing dimension must be preserved on the dequeued value.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // The Attrs object is passed within the same expression on purpose: it may
  // only view the braced shape list.
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                              ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeued = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: [3,7,-1]", PropToString(dequeued.front()));
}
472
TEST_F(GraphPropertiesTest, Queues) {
  // Create a graph with known input shapes, and propagate the shapes through a
  // couple of queues.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
  Output rnd =
      ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
  Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
  auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
  auto dequeue1 =
      ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});

  auto q2 =
      ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
  Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
  auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
  auto dequeue2 =
      ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});

  // Queue4 is fed by two different enqueue ops (Square2 and Dequeue2).
  auto q4 =
      ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
  auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
  auto enqueue4_2 =
      ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
  auto dequeue4 =
      ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});

  // Create a queue that takes in three tensors.
  auto q5 = ops::RandomShuffleQueue(
      root.WithOpName("Queue5"),
      {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
  Output rnd2 =
      ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
  Output rnd3 =
      ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
  auto enqueue5 =
      ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
  auto dequeue5 = ops::QueueDequeue(
      root.WithOpName("Dequeue5"), q5,
      {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto props1 = properties.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, props1.size());
  EXPECT_EQ("float: [3,7]", PropToString(props1[0]));

  const auto props2 = properties.GetOutputProperties("Dequeue2");
  ASSERT_EQ(1, props2.size());
  EXPECT_EQ("float: [3,7]", PropToString(props2[0]));

  // Queue4 receives tensors from two enqueue ops (Square2 and Dequeue2).
  // Verify that the shapes of all enqueued tensors are merged properly to
  // determine the shape of the data coming out of the queue.
  const auto props4 = properties.GetOutputProperties("Dequeue4");
  ASSERT_EQ(1, props4.size());
  EXPECT_EQ("float: [3,7]", PropToString(props4[0]));

  // The dequeue5 op shape is known.
  const auto props5 = properties.GetOutputProperties("Dequeue5");
  ASSERT_EQ(3, props5.size());
  EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
  EXPECT_EQ("double: [10]", PropToString(props5[1]));
  EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
}
543
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      x = tf.constant(2)
      y = tf.constant(5)
      z = tf.ones([1,1,1])
      def f1(): return tf.concat([z, z], axis=0)
      def f2(): return tf.concat([z, z], axis=1)
      r = tf.cond(tf.less(x, y), f1, f2)
      tf.concat([r, r], axis=2)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "merge_without_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // The Merge output combines [2,1,1] and [1,2,1], so only the last dimension
  // remains known.
  std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
  std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
                                       "float: [1,2,1]"};
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto props = properties.GetOutputProperties(nodes[i]);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << nodes[i];
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(expected_outputs[i], PropToString(prop));
  }

  // The "Less" node should be fed by 2 int32 scalar constant values.
  const auto props = properties.GetInputProperties("Less");
  EXPECT_EQ(2, props.size());
  for (const auto& prop : props) {
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_TRUE(prop.has_value());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
}
585
TEST_F(GraphPropertiesTest, WhileLoop) {
  // Test graph produced in python using:
  /*
     with tf.Graph().as_default():
       i0 = tf.constant(0)
       m0 = tf.placeholder([-1, 2])
       c = lambda i, m: i < 10
       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
       r = tf.while_loop(
           c, b, loop_vars=[i0, m0],
           shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
       with open('/tmp/graph.pbtxt', 'w') as f:
         f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};
  for (const string& node : nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,2]", PropToString(prop));
  }

  // The loop outputs batch dim should be different from the input batch dim
  // since we concatenated along the batch dim. Both are expected to be
  // symbolic unknown dims (<= -2); different values mean the dims are known
  // to differ.
  auto shape_in = properties.GetOutputProperties("ones").at(0).shape();
  auto shape_out = properties.GetOutputProperties("while/Exit_1").at(0).shape();
  // Fatal checks: dim(0) below is only valid for rank >= 1.
  ASSERT_GT(shape_in.dim_size(), 0);
  ASSERT_GT(shape_out.dim_size(), 0);
  EXPECT_GE(-2, shape_in.dim(0).size());
  EXPECT_GE(-2, shape_out.dim(0).size());
  EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
}
625
TEST_F(GraphPropertiesTest, NestedLoop) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=2)

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape([None, 1, None])])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        j, y = inner(0, x)
        return i+1, tf.concat([x, x], axis=0)

      r = tf.while_loop(outer_cond, outer_body,
                        loop_vars=[i0, tf.ones([1, 1, 1])],
                        shape_invariants=[i0.get_shape(),
                                          tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "nested_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,1,1]", PropToString(prop));
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop));
  }
}
684
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q = tf.FIFOQueue(1, "float")

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=0)

        return tf.while_loop(inner_cond, inner_body,
                             loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape(None)])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        q.enqueue(x)
        y = tf.concat([x, x], axis=2)
        inner(0, q.dequeue())
        return i+1, y

      i, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[i0, tf.ones([1, 1, 1])],
                           shape_invariants=[i0.get_shape(),
                                             tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_queues.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [1,1,-1]", PropToString(prop));
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop));
  }
}
747
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
        v = tf.get_variable(initializer=i0, name='loop_var')

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j + 1, y + y

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        y = x + x
        inner(0, v)
        return i + 1, y

      v, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[v, tf.constant(1)])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_resource_vars.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
}
805
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q0 = tf.FIFOQueue(1, "float")
      q0.enqueue(tf.ones([2, 2]))
      q1 = tf.FIFOQueue(1, "float")

      def c(i, m):
        return i < 10

      def b(i, m):
        return i+1, tf.concat([m, m], axis=0)

      i, m = tf.while_loop(
          c, b, loop_vars=[i0,  q0.dequeue()],
          shape_invariants=[i0.get_shape(), tf.TensorShape(None)])

      q1.enqueue(m)
      v = q1.dequeue()
      tf.concat([v, v], axis=1)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "queues_and_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};

  for (const string& node : nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal check: props[0] is read below.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,2]", PropToString(prop));
  }

  // The loop result goes through q1 and is concatenated with itself along
  // axis 1, doubling the known inner dimension.
  const auto props = properties.GetOutputProperties("concat");
  ASSERT_FALSE(props.empty());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ("float: [-1,4]", PropToString(prop));
}
854
TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
  // Each restore-style op below is assigned into `var` ([128,256]); the test
  // expects GraphProperties to report that shape on the restore outputs.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
                             DataType::DT_FLOAT);
  Output filename =
      ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
  Output tensor_name =
      ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
  Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
                                DataType::DT_FLOAT);
  Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);

  Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
                                      string("256 256 0,128:-"), TensorShape());
  Output restore_slice =
      ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
                        shape_and_slice, DataType::DT_FLOAT);
  Output init_restore_slice =
      ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);

  // NOTE(review): despite its name, "restore_v2" is built with RestoreSlice,
  // so the RestoreV2 op itself is not exercised here. ops::RestoreV2 takes
  // 1-D tensor_names / shape_and_slices inputs and returns an OutputList, so
  // switching needs different inputs than the scalars above. TODO.
  Output restore_v2 =
      ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
                        shape_and_slice, DataType::DT_FLOAT);
  Output init_restore_v2 =
      ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("init_restore");

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Checks that `node_name` has exactly one output, a float [128,256] tensor.
  // The size check turns a missing inference result into a test failure
  // instead of an out-of-bounds read.
  auto expect_restored_shape = [this, &properties](const string& node_name) {
    const auto& props = properties.GetOutputProperties(node_name);
    ASSERT_EQ(1, props.size()) << node_name;
    EXPECT_EQ(DT_FLOAT, props[0].dtype()) << node_name;
    EXPECT_EQ("float: [128,256]", PropToString(props[0])) << node_name;
  };
  expect_restored_shape("restore");
  expect_restored_shape("restore_slice");
  expect_restored_shape("restore_v2");

  // Check input shapes of assign op are propagated correctly.
  const auto& input_props = properties.GetInputProperties("init_restore");
  ASSERT_EQ(2, input_props.size());
  const OpInfo::TensorProperties& input_prop = input_props[1];
  EXPECT_EQ(DT_FLOAT, input_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(input_prop));
}
911
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
  // A single Restore output feeds two Assigns: "init" targets a variable of
  // unknown shape and "init2" one of shape [128,256]. The Restore output is
  // expected to pick up the fully known [128,256].
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output unknown_var = ops::Variable(root.WithOpName("var"),
                                     PartialTensorShape(), DataType::DT_FLOAT);
  Output known_var =
      ops::Variable(root.WithOpName("var2"), TensorShape({128, 256}),
                    DataType::DT_FLOAT);
  Output file_const =
      ops::Const(root.WithOpName("filename"), string("model"), TensorShape());
  Output name_const =
      ops::Const(root.WithOpName("tensorname"), string("a"), TensorShape());
  Output restored = ops::Restore(root.WithOpName("restore"), file_const,
                                 name_const, DataType::DT_FLOAT);
  Output assign_unknown = ops::Assign(root.WithOpName("init"), unknown_var,
                                      restored);
  Output assign_known = ops::Assign(root.WithOpName("init2"), known_var,
                                    restored);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  for (const char* fetch_node : {"init", "init2"}) {
    item.fetch.push_back(fetch_node);
  }

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const OpInfo::TensorProperties& restored_prop =
      properties.GetOutputProperties("restore")[0];
  EXPECT_EQ(DT_FLOAT, restored_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(restored_prop));
}
940
TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
  // Const values should be visible on each Const's output and, through an
  // Identity, on that Identity's input and output. The cases are a 1-D vector
  // (a), a scalar (b) and a rank-3 tensor (c).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output a1 = ops::Identity(s.WithOpName("a1"), a);
  Output b = ops::Const(s.WithOpName("b"), 99, {});
  Output b1 = ops::Identity(s.WithOpName("b1"), b);
  Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
  Output c1 = ops::Identity(s.WithOpName("c1"), c);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Check output shapes.
  EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
  EXPECT_EQ("int32: [2]",
            PropToString(properties.GetOutputProperties("a1")[0]));
  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
  EXPECT_EQ("int32: [4,4,4]",
            PropToString(properties.GetOutputProperties("c")[0]));
  EXPECT_EQ("int32: [4,4,4]",
            PropToString(properties.GetOutputProperties("c1")[0]));

  // Check has_value.
  EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
  // The value of the rank-3 tensor is propagated through Identity as well,
  // not only values of scalars and 1-D vectors.
  EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());

  // Check values.
  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
  ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
  ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
  ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
  ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
  // c is a [4,4,4] tensor filled with ones.
  const std::vector<int64> c_values(4 * 4 * 4, 1);
  ExpectTensorValues(c_values, properties.GetOutputProperties("c")[0].value());
  ExpectTensorValues(c_values, properties.GetInputProperties("c1")[0].value());
  ExpectTensorValues(c_values,
                     properties.GetOutputProperties("c1")[0].value());
}
996
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 5, {2});
  Output b = ops::Identity(s.WithOpName("b"), a);
  Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only b's shape ([2]) but also b's value ({5, 5}) to infer
  // its output shape. Hence, the Identity op (b) must forward a's value as
  // output_tensors_as_shape.
  Output d = ops::Fill(s.WithOpName("fill"), b, c);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  // Guard the index so a missing result fails the test instead of reading
  // out of bounds.
  ASSERT_EQ(1, out_props.size());
  EXPECT_EQ("float: [5,5]", PropToString(out_props[0]));
}
1015
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
  // When using aggressive_shape_inference, we run EvaluateNode() for
  // allowlisted ops and small input / output tensors. For instance, Fill op is
  // evaluated and produces output tensor value if output tensor size is small
  // (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
  // This is to avoid wasting time and memory for producing huge tensors (e.g.,
  // initializing a large table using Fill).
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 4, {2});  // 4x4
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is small; expect output values of Fill op.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());
  }
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 1000, {4});  // 1000x1000x1000x1000
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is huge; in that case we skip value inference.
    // Otherwise, it'd be too much overhead.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }
}
1063
TEST_F(GraphPropertiesTest, PackWithConstInput) {
  // Four scalar Consts are packed into the vector {1, 2, 3, 4}, which is then
  // used as Fill's shape input. Inferring Fill's output requires the packed
  // value, not just its [4] shape.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output dim0 = ops::Const(root.WithOpName("a"), 1, {});
  Output dim1 = ops::Const(root.WithOpName("b"), 2, {});
  Output dim2 = ops::Const(root.WithOpName("c"), 3, {});
  Output dim3 = ops::Const(root.WithOpName("d"), 4, {});
  // ops::Stack emits a Pack node.
  Output packed =
      ops::Stack(root.WithOpName("pack"), {dim0, dim1, dim2, dim3});
  Output fill_value = ops::Const(root.WithOpName("const"), 0.1f, {});
  Output filled = ops::Fill(root.WithOpName("fill"), packed, fill_value);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  EXPECT_EQ("float: [1,2,3,4]",
            PropToString(properties.GetOutputProperties("fill")[0]));
}
1086
TEST_F(GraphPropertiesTest, RankOp) {
  // The rank of a [4,4,4] Const is the scalar 3. That value must be known on
  // the Rank output and must survive an Identity.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output cube = ops::Const(root.WithOpName("Const"), 1, {4, 4, 4});
  Output rank_out = ops::Rank(root.WithOpName("Rank"), cube);
  Output rank_copy = ops::Identity(root.WithOpName("Identity"), rank_out);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  for (const char* node_name : {"Rank", "Identity"}) {
    const OpInfo::TensorProperties& scalar_prop =
        properties.GetOutputProperties(node_name)[0];
    EXPECT_EQ("int32: []", PropToString(scalar_prop));
    EXPECT_TRUE(scalar_prop.has_value());
    ExpectTensorValues({3}, scalar_prop.value());
  }
}
1108
TEST_F(GraphPropertiesTest, SizeOp) {
  // A [1,2,3,4] Const holds 24 elements. Size must report the scalar 24, and
  // the value must survive an Identity.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output tensor = ops::Const(root.WithOpName("Const"), 1, {1, 2, 3, 4});
  Output size_out = ops::Size(root.WithOpName("Size"), tensor);
  Output size_copy = ops::Identity(root.WithOpName("Identity"), size_out);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  for (const char* node_name : {"Size", "Identity"}) {
    const OpInfo::TensorProperties& scalar_prop =
        properties.GetOutputProperties(node_name)[0];
    EXPECT_EQ("int32: []", PropToString(scalar_prop));
    EXPECT_TRUE(scalar_prop.has_value());
    ExpectTensorValues({24}, scalar_prop.value());
  }
}
1130
TEST_F(GraphPropertiesTest, PackWithConstMinus1AndReshapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape0 = ops::Const(s.WithOpName("shape0"), 4, {});
  Output shape1 = ops::Const(s.WithOpName("shape1"), -1, {});
  Output pack = ops::Stack(s.WithOpName("pack"), {shape0, shape1});
  // pack is [2], with values {4, -1}.

  Output x0_ = ops::Placeholder(s.WithOpName("x0_"), DataType::DT_FLOAT);
  Output x1_ = ops::Placeholder(s.WithOpName("x1_"), DataType::DT_FLOAT);

  Output x0 = ops::Reshape(s.WithOpName("x0"), x0_, pack);
  Output x1 = ops::Reshape(s.WithOpName("x1"), x1_, pack);
  // Two unknown-rank tensors (x0_ and x1_) are reshaped with pack {4, -1}, so
  // both output shapes are [4, -1]. Although both Reshape ops use the same
  // shape input, their output shapes can differ: the unknown dims (-1) of x0
  // and x1 are not necessarily the same.

  // s0 and s1 are the condition inputs of the Select ops. Note that s0 has a
  // fully defined shape, while s1 has unknown shape.
  Output s0 = ops::Const(s.WithOpName("s0"), true, {4, 16});
  Output s1 = ops::Placeholder(s.WithOpName("s1"), DataType::DT_BOOL);

  Output y0 = ops::Placeholder(s.WithOpName("y0"), DataType::DT_FLOAT);
  Output y1 = ops::Placeholder(s.WithOpName("y1"), DataType::DT_FLOAT);

  // We instantiate SelectV2, but will replace it with Select. The shape
  // inference function for Select links all inputs and outputs as they should
  // have the same shapes.
  Output z0 = ops::SelectV2(s.WithOpName("z0"), s0, x0, y0);
  Output z1 = ops::SelectV2(s.WithOpName("z1"), s1, x1, y1);

  // For z0, as we know the shape of s0, the symbolic shape manager in shape
  // inference will make the shapes of x0, y0, and z0 equal to the shape of s0,
  // which is [4, 16].
  // For z1, s1 and y1 have unknown shapes, so the best we can infer for x1,
  // y1 and z1 is [4, -1] (from x1).
  // Note that x0 and x1 share the same shape input to the Reshape op, but
  // -1 in the shape input must not be treated as the same symbolic unknown
  // dim; it is merely a constant value -1 for identifying the unknown dim for
  // the Reshape operation.

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Replace SelectV2 op with Select op.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    auto* node = item.graph.mutable_node(i);
    if (node->op() == "SelectV2") {
      node->set_op("Select");
    }
  }

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  for (const auto& node_name : {"x0", "y0", "z0"}) {
    const auto& out_props = properties.GetOutputProperties(node_name);
    ASSERT_EQ(1, out_props.size()) << node_name;
    EXPECT_EQ("float: [4,16]", PropToString(out_props[0])) << node_name;
  }
  {
    const auto& out_props = properties.GetOutputProperties("s0");
    ASSERT_EQ(1, out_props.size());
    EXPECT_EQ("bool: [4,16]", PropToString(out_props[0]));
  }

  for (const auto& node_name : {"x1", "y1", "z1"}) {
    const auto& out_props = properties.GetOutputProperties(node_name);
    ASSERT_EQ(1, out_props.size()) << node_name;
    EXPECT_EQ("float: [4,-1]", PropToString(out_props[0])) << node_name;
  }
  // The condition input of Select can be either a vector or have the same
  // shape as the other inputs/output. So even though x1, y1 and z1 are known
  // to be [4, ?], s1 could be [4, ?] or a vector; hence, its shape should
  // remain unknown.
  {
    const auto& out_props = properties.GetOutputProperties("s1");
    ASSERT_EQ(1, out_props.size());
    EXPECT_EQ("bool: ?", PropToString(out_props[0]));
  }
}
1212
TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
  // ops of Consts.
  // If output_tensors_as_shape is not set for those Identity ops, or the Pack
  // op doesn't take input_tensors_as_shape, Fill op's shape input has no
  // value; hence, its output shape becomes unknown.
  Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
  Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
  Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
  Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
  Output a = ops::Identity(s.WithOpName("a"), a0);
  Output b = ops::Identity(s.WithOpName("b"), b0);
  Output c = ops::Identity(s.WithOpName("c"), c0);
  Output d = ops::Identity(s.WithOpName("d"), d0);
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  ASSERT_EQ(1, out_props.size());
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
}
1244
TEST_F(GraphPropertiesTest, FunctionWithDtResourceInput) {
  // Function ops may have DT_RESOURCE input. If shapes and dtypes are not
  // properly set through the DT_RESOURCE _Arg, we cannot infer output shapes
  // of such function ops.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_with_dt_resource_input.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));

  // This graph evaluates FunctionWithDtResourceInput with two inputs:
  //   x [DT_FLOAT Const],
  //   _Arg [DT_RESOURCE _Arg]
  // and has two outputs:
  //   z1 = x + _Arg
  //   z2 = x
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(false));
    const auto& out_props =
        properties.GetOutputProperties("FunctionWithDtResourceInput");
    // Fatal check: both outputs are indexed below, so a size mismatch must
    // stop the test rather than read out of bounds.
    ASSERT_EQ(2, out_props.size());
    EXPECT_EQ("float: [1,3]", PropToString(out_props[0]));
    EXPECT_EQ("float: [1,3]", PropToString(out_props[1]));
  }

  {
    // Delete _handle_dtypes and _handle_shapes attr for the input _Arg node.
    for (int i = 0; i < item.graph.node_size(); i++) {
      auto* node = item.graph.mutable_node(i);
      if (node->name() == "y") {  // _Arg node with DT_RESOURCE
        node->mutable_attr()->erase("_handle_dtypes");
        node->mutable_attr()->erase("_handle_shapes");
        break;
      }
    }
    // We cannot infer the function output shape correctly without those
    // attrs, but inference still shouldn't fail, and some shapes can still be
    // inferred. In this test graph, z2 of the function node just returns the
    // x input; hence, even if _Arg's shape cannot be inferred, we can infer
    // the z2 output shape.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(false));
    const auto& out_props =
        properties.GetOutputProperties("FunctionWithDtResourceInput");
    ASSERT_EQ(2, out_props.size());
    // Without the shape and dtype attrs, we don't know _Arg's shape; hence,
    // x + _Arg is unknown.
    EXPECT_EQ("float: ?", PropToString(out_props[0]));
    // The 2nd output is just x, so even if _Arg's shape is unknown, we can
    // infer this output shape.
    EXPECT_EQ("float: [1,3]", PropToString(out_props[1]));
  }
}
1302
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
  // MyFillFunc (added from the fixture's function_lib_) is called with a
  // Const shape {1, 2, 3, 4}. Its output shape is expected to be that value.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib_));
  Output shape = ops::Const(root.WithOpName("shape"), {1, 2, 3, 4});
  Output value = ops::Const(root.WithOpName("value"), 0.1f, {});

  tensorflow::Node* func_op = nullptr;
  tensorflow::NodeBuilder call_builder("MyFillFunc", "MyFillFunc",
                                       root.graph()->op_registry());
  const auto shape_in = tensorflow::ops::AsNodeOut(root, shape);
  const auto value_in = tensorflow::ops::AsNodeOut(root, value);
  TF_ASSERT_OK(call_builder.Input(shape_in)
                   .Input(value_in)
                   .Finalize(root.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const OpInfo::TensorProperties& func_out =
      properties.GetOutputProperties("MyFillFunc")[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(func_out));
}
1324
TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
  // Variant of FunctionWithConstInput where the shape argument is an Identity
  // of the Const. The function call then has no Const input, so the value
  // forwarded by Identity (as a tensor-as-shape) is what must reach
  // MyFillFunc.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib_));
  Output raw_shape = ops::Const(root.WithOpName("shape_"), {1, 2, 3, 4});
  Output shape = ops::Identity(root.WithOpName("shape"), raw_shape);
  Output value = ops::Const(root.WithOpName("value"), 0.1f, {});

  tensorflow::Node* func_op = nullptr;
  tensorflow::NodeBuilder call_builder("MyFillFunc", "MyFillFunc",
                                       root.graph()->op_registry());
  const auto shape_in = tensorflow::ops::AsNodeOut(root, shape);
  const auto value_in = tensorflow::ops::AsNodeOut(root, value);
  TF_ASSERT_OK(call_builder.Input(shape_in)
                   .Input(value_in)
                   .Finalize(root.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const OpInfo::TensorProperties& func_out =
      properties.GetOutputProperties("MyFillFunc")[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(func_out));
}
1350
TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
  // MyFunc(x) = Identity(x) over int32.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32"},                                               // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));

  // The function is fed a Const {5, 7}. Its output should carry both the [2]
  // shape and the {5, 7} value (output_tensors_as_shape) of that input.
  Output shape = ops::Const(root.WithOpName("shape"), {5, 7}, {2});
  tensorflow::NodeBuilder call_builder("MyFunc", "MyFunc",
                                       root.graph()->op_registry());
  const auto shape_in = tensorflow::ops::AsNodeOut(root, shape);
  tensorflow::Node* func_op = nullptr;
  TF_ASSERT_OK(call_builder.Input(shape_in).Finalize(root.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const OpInfo::TensorProperties& func_out =
      properties.GetOutputProperties("MyFunc")[0];
  EXPECT_EQ("int32: [2]", PropToString(func_out));
  EXPECT_TRUE(func_out.has_value());
  ExpectTensorValues({5, 7}, func_out.value());
  const OpInfo::TensorProperties& func_in =
      properties.GetInputProperties("MyFunc")[0];
  ExpectTensorValues({5, 7}, func_in.value());
}
1386
TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
  // MyFunc(x, y) = Add(x, y) over int32.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32", "y: int32"},                                   // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:z:0"}});                                        // Returns
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));

  // Both arguments are the same Const {5, 7}.
  Output shape = ops::Const(root.WithOpName("shape"), {5, 7}, {2});
  const auto shape_in = tensorflow::ops::AsNodeOut(root, shape);
  tensorflow::NodeBuilder call_builder("MyFunc", "MyFunc",
                                       root.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  TF_ASSERT_OK(call_builder.Input(shape_in)
                   .Input(shape_in)
                   .Finalize(root.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  {
    // Non-aggressive inference: only the shape of the sum is known, not its
    // value.
    GraphProperties conservative(item);
    TF_ASSERT_OK(conservative.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const OpInfo::TensorProperties& sum_prop =
        conservative.GetOutputProperties("MyFunc")[0];
    EXPECT_EQ("int32: [2]", PropToString(sum_prop));
    EXPECT_FALSE(sum_prop.has_value());
  }

  {
    // Aggressive inference evaluates the Add: {5, 7} + {5, 7} = {10, 14}.
    GraphProperties aggressive(item);
    TF_ASSERT_OK(aggressive.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const OpInfo::TensorProperties& sum_prop =
        aggressive.GetOutputProperties("MyFunc")[0];
    EXPECT_EQ("int32: [2]", PropToString(sum_prop));
    EXPECT_TRUE(sum_prop.has_value());

    ExpectTensorValues({10, 14}, sum_prop.value());
    ExpectTensorValues({5, 7}, aggressive.GetInputProperties("MyFunc")[0].value());
    ExpectTensorValues({5, 7}, aggressive.GetInputProperties("MyFunc")[1].value());
  }
}
1443
1444 // Same as the above, but float values; also, one of the function input is
1445 // Identity of Const.
TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValueFloat) {
  FunctionDefLibrary library;
  // Function that adds two input values.
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float", "y: float"},                                   // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:z:0"}});                                        // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));

  // First argument: Const {5.0, 7.0}. Second argument: Identity of it.
  Output x1 = ops::Const(s.WithOpName("x1"), {5.0f, 7.0f}, {2});
  // Named "x2" to match the variable; it previously reused the op name "x1",
  // which Scope silently uniquified to "x1_1".
  Output x2 = ops::Identity(s.WithOpName("x2"), x1);
  auto _x1 = tensorflow::ops::AsNodeOut(s, x1);
  auto _x2 = tensorflow::ops::AsNodeOut(s, x2);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  TF_ASSERT_OK(builder.Input(_x1).Input(_x2).Finalize(s.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, the internal function does not
    // evaluate output value.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("MyFunc");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [2]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }

  {
    GraphProperties properties(item);
    // With aggressive_shape_inference, output value is evaluated.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("MyFunc");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [2]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());

    ExpectFloatTensorValues({10.0, 14.0}, out_prop0.value());
    const auto& in_props = properties.GetInputProperties("MyFunc");
    ASSERT_EQ(2, in_props.size());
    ExpectFloatTensorValues({5.0, 7.0}, in_props[0].value());
    ExpectFloatTensorValues({5.0, 7.0}, in_props[1].value());
  }
}
1503
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
  // Create graph with a function that takes a scalar value so that we use
  // Placeholder with scalar as for input to the function shape inference.
  // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
  // the input; all tensors are scalars.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float"},                                               // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
  Output placeholder =
      ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({})));
  Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
  auto _identity = tensorflow::ops::AsNodeOut(s, identity);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op = nullptr;
  TF_ASSERT_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // GraphDef producer versions <= 21 infer the output shape of a Placeholder
  // with an empty shape as unknown instead of scalar, which would make the
  // checks below meaningless. Hence, this precondition is fatal.
  ASSERT_GT(item.graph.versions().producer(), 21);

  // MyFunc output shouldn't be unknown rank.
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());
}
1543
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x,y)

    with tf.Graph().as_default():
      x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = MyAdd(x, y)
      z = MyAdd(x, z)
  */
  // Static inference should propagate the [1,2] input shapes through the
  // MyAdd call to its single output.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "simple_function.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Sizes are checked with ASSERT because the lines after them index into
  // these containers; a wrong size must not lead to an out-of-bounds read.
  const auto& out_props = properties.GetOutputProperties("MyAdd_55e046a8");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_FALSE(out_prop.shape().unknown_rank());
  ASSERT_EQ(2, out_prop.shape().dim_size());
  EXPECT_EQ(1, out_prop.shape().dim(0).size());
  EXPECT_EQ(2, out_prop.shape().dim(1).size());

  const auto& in_props = properties.GetInputProperties("MyAdd_55e046a8");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1580
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
  // Static shape inference on large_function_graph.pbtxt. Node "y0" should
  // get fully known shapes for both of its outputs and all four of its inputs.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_function_graph.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT (not EXPECT) on the sizes: the checks below index into these
  // vectors and must not run past their end.
  const auto& out_props = properties.GetOutputProperties("y0");
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));

  const auto& in_props = properties.GetInputProperties("y0");
  ASSERT_EQ(4, in_props.size());

  const OpInfo::TensorProperties& in_prop0 = in_props[0];
  EXPECT_EQ("float: [64]", PropToString(in_prop0));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));

  const OpInfo::TensorProperties& in_prop2 = in_props[2];
  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));

  const OpInfo::TensorProperties& in_prop3 = in_props[3];
  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
1613
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
  // Test graph produced in python using:
  /*
    @function.Defun(noinline=True)
    def MyFunc():
      @function.Defun(*[tf.float32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[tf.float32] * 2)
      def Body(n, x):
        return n - 1, x + n

      i = tf.constant(10)
      return functional_ops.While([i, 0.], Cond, Body)

    with tf.Graph().as_default():
      z = MyFunc()
  */
  // The function wraps a functional While with two loop variables, so the
  // call has two outputs; both should get a known rank.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_functional_while.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
  // ASSERT before indexing out_props[1] below.
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_INT32, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
  EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
1651
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
  // Loads function_error.pbtxt (presumably a function whose body fails shape
  // inference -- TODO confirm against the test data). InferStatically is
  // still expected to succeed. The call's output is reported with unknown
  // rank, while its inputs keep their known [1,2] shapes.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_error.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Sizes are ASSERTed because the elements are indexed right after.
  const auto& out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(1, out_props.size());

  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_TRUE(out_prop.shape().unknown_rank());

  const auto& in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1676
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
  */
  // Both tf.case branches produce [1,2] constants. The first MyAdd input
  // (the merged case result) and the MyAdd output should therefore be
  // inferred as float [1,2].
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the sizes so the element accesses below stay in bounds.
  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1710
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  // Same as FunctionSwitchStaticShapeInference, but the case predicate is
  // false so the default branch is taken. The shapes must still be [1,2].
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_2.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the sizes so the element accesses below stay in bounds.
  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1743
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      c = tf.add(x, a)
      d = tf.add(y, b)
      return c

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  // MyAdd's inputs have different shapes: [1,2] from the case and [1,3] from
  // z. Inference must keep them distinct and report the returned value c as
  // [1,2].
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_shapes.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // ASSERT the sizes so the element accesses below stay in bounds.
  const auto& out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto& in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
1780
TEST_F(GraphPropertiesTest, SymbolicShapes) {
  // Build a simple graph with placeholders of unknown dimensions. These
  // dimensions will be encoded symbolically (as sizes <= -2), and outputs that
  // share a dimension must carry the same symbolic value.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b =
      ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output c = ops::Identity(s.WithOpName("c"), a);
  Output d = ops::Identity(s.WithOpName("d"), b);
  Output e = ops::Add(s.WithOpName("e"), c, d);
  Output f = ops::Add(s.WithOpName("f"), a, c);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output g = ops::Shape(s.WithOpName("g"), c);
  Output h = ops::Fill(s.WithOpName("h"), g, zero);
  Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
  Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  // Every dim_size() that guards a later dim(i) access is ASSERTed so a wrong
  // rank cannot lead to an out-of-range proto access.
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(shape_a.dim_size(), shape_c.dim_size());
  EXPECT_GE(-2, shape_a.dim(0).size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
  EXPECT_GE(-2, shape_a.dim(1).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());

  PartialTensorShape shape(shape_a);
  EXPECT_FALSE(shape.IsFullyDefined());
  EXPECT_FALSE(shape.unknown_rank());

  const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
  const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
  ASSERT_EQ(1, shape_b.dim_size());
  ASSERT_EQ(shape_b.dim_size(), shape_d.dim_size());
  EXPECT_GE(-2, shape_b.dim(0).size());
  EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
  EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());

  const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
  ASSERT_EQ(2, shape_e.dim_size());
  EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
  EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
  EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());

  const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
  ASSERT_EQ(2, shape_f.dim_size());
  EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
  EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());

  const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
  // Guard the shape_h accesses below. The original checked shape_f here,
  // which is a copy-paste slip.
  ASSERT_EQ(2, shape_h.dim_size());
  EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
  EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());

  const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
  ASSERT_EQ(1, shape_j.dim_size());
  EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
1849
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(scope.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(scope.WithOpName("b"), 2.0f, {1});
  Output c = ops::Const(scope.WithOpName("c").ColocateWith(a), 3.0f, {1});
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  // Rebuild the graph without node "a", as some graph optimization pass
  // might, even though node "c" is still colocated with "a". This is fine at
  // this late stage of graph execution: the colocation constraints have
  // already been validated and device placement of nodes has completed.
  GraphDef pruned_graph;
  for (const auto& node : item.graph.node()) {
    if (node.name() == "a") continue;
    *pruned_graph.add_node() = node;
  }
  item.graph.Swap(&pruned_graph);

  // InferStatically does not validate colocation constraints internally, so
  // it should return OK despite the dangling reference to "a".
  GraphProperties properties(item);
  TF_EXPECT_OK(properties.InferStatically(false));
}
1873
TEST_F(GraphPropertiesTest, ShapeTracking) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b =
      ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  // One ShapeN node yields the shapes of both placeholders. Each shape is
  // then used to Fill a new tensor with zeros.
  auto shapes = ops::ShapeN(s.WithOpName("shapes"), {a, b});
  Output o1 = ops::Fill(s.WithOpName("o1"), shapes[0], zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), shapes[1], zero);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Debug string of the shape of output 0 of `node_name`.
  auto output_shape_str = [&properties](const char* node_name) {
    return properties.GetOutputProperties(node_name).at(0).shape()
        .DebugString();
  };
  // Each Fill output should carry exactly the (possibly symbolic) shape of
  // the placeholder its dims were taken from.
  EXPECT_EQ(output_shape_str("a"), output_shape_str("o1"));
  EXPECT_EQ(output_shape_str("b"), output_shape_str("o2"));
}
1899
TEST_F(GraphPropertiesTest, FedNodes) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  // ASSERT_TRUE instead of CHECK: a failure is reported as a test failure
  // rather than crashing the whole test binary.
  ASSERT_TRUE(fake_input.NextItem(&item));

  {
    // Conservative shape analysis: the shape of fed ports should be unknown
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(false));
    for (const auto& node : item.graph.node()) {
      if (node.op() == "Const") {
        continue;
      }
      // Sizes are ASSERTed because element 0 is read immediately after.
      const auto& in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      const auto& out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      const OpInfo::TensorProperties& out_prop = out_props[0];

      if (node.name() == "x") {
        // x is fed: its input should have a known shape, while its output
        // doesn't
        EXPECT_FALSE(in_prop.shape().unknown_rank());
        ASSERT_EQ(1, in_prop.shape().dim_size());
        EXPECT_EQ(2, in_prop.shape().dim(0).size());
        EXPECT_TRUE(out_prop.shape().unknown_rank());
      } else if (node.op() == "Square" || node.op() == "AddN") {
        // These nodes are in the fanout of x: their shapes should be unknown.
        EXPECT_TRUE(in_prop.shape().unknown_rank());
        EXPECT_TRUE(out_prop.shape().unknown_rank());
      }
    }
  }
  {
    // Optimistic shape analysis: the shape of fed ports should be derived from
    // the shape of the fanin.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(true));
    for (const auto& node : item.graph.node()) {
      if (node.op() == "Square" || node.op() == "AddN") {
        const auto& in_props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, in_props.size());
        const OpInfo::TensorProperties& in_prop = in_props[0];
        EXPECT_EQ(DT_FLOAT, in_prop.dtype());
        EXPECT_FALSE(in_prop.shape().unknown_rank());
        EXPECT_EQ(2, in_prop.shape().dim_size());
        const auto& out_props = properties.GetOutputProperties(node.name());
        ASSERT_EQ(1, out_props.size());
        const OpInfo::TensorProperties& out_prop = out_props[0];
        EXPECT_EQ(in_prop.dtype(), out_prop.dtype());
        EXPECT_EQ(in_prop.shape().DebugString(),
                  out_prop.shape().DebugString());
      }
    }
  }
}
1960
TEST_F(GraphPropertiesTest, Performance) {
  // Loads a large graph with many nested loops; shape inference over it is
  // expected to finish quickly.
  GrapplerItem item;
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "large_graph.pbtxt.html");
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));

  // Registry used by AddDefaultAttrsToGraphDef: the global op registry
  // combined with the graph's own function library.
  const FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                                   item.graph.library());
  TF_ASSERT_OK(
      AddDefaultAttrsToGraphDef(&item.graph, function_library, 0, true));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
}
1976
TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
  // Slices of a Shape tensor should keep the symbolic dims they select:
  //   b = shape(a)[0:1:1] -> dim 0 of a;  c = shape(a)[1:2:1] -> dim 1 of a.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  auto shp = ops::Shape(s.WithOpName("shape"), {a});

  Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
  Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
  Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});

  Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
  Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
  const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
  // ASSERT the ranks: the comparisons below read dim(0)/dim(1) directly.
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(1, shape_o1.dim_size());
  ASSERT_EQ(1, shape_o2.dim_size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
}
2009
TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Fully defined input of shape [5, 480, 40, 1]. Its Shape is sliced below.
  Output placeholder =
      ops::Placeholder(scope.WithOpName("input_placeholder"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
  auto input_shape = ops::Shape(scope.WithOpName("input_shape"), placeholder);

  Output begin = ops::Const(scope.WithOpName("begin"), {0}, {1});
  Output end = ops::Const(scope.WithOpName("end"), {3}, {1});
  Output stride = ops::Const(scope.WithOpName("stride"), {1}, {1});

  // ShrinkAxisMask(1) shrinks axis 0, so the slice selects the element
  // input_shape[0]; the test expects its value to be 5.
  Output slice =
      ops::StridedSlice(scope.WithOpName("slice"), input_shape, begin, end,
                        stride, ops::StridedSlice::ShrinkAxisMask(1));

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  {
    // Non-aggressive inference cannot compute the output value of a
    // StridedSlice with ShrinkAxisMask.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto& slice_prop = properties.GetOutputProperties("slice").at(0);
    EXPECT_FALSE(slice_prop.has_value());
  }

  {
    // Aggressive inference does compute it.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value());
    const auto& slice_prop = properties.GetOutputProperties("slice").at(0);
    ExpectTensorValues({5}, slice_prop.value());
  }
}
2053
TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
  // With aggressive shape inference, constant values should be propagated
  // through OnesLike and chains of Add ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output b = ops::Const(s.WithOpName("b"), {8, 8}, {2});
  Output c = ops::Const(s.WithOpName("c"), {2, 2}, {2});

  Output a1 = ops::OnesLike(s.WithOpName("a1"), a);
  Output a_plus_one = ops::Add(s.WithOpName("a_plus_one"), a, a1);
  Output a_plus_a = ops::Add(s.WithOpName("a_plus_a"), a, a);
  Output b_plus_2a = ops::Add(s.WithOpName("b_plus_2a"), b, a_plus_a);
  Output c_plus_b_plus_2a =
      ops::Add(s.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));

  // Check output shapes and values. at(0) is bounds-checked, unlike [0], and
  // has_value() is ASSERTed before value() is read.
  const auto& a_plus_one_prop =
      properties.GetOutputProperties("a_plus_one").at(0);
  EXPECT_EQ("int32: [2]", PropToString(a_plus_one_prop));
  ASSERT_TRUE(a_plus_one_prop.has_value());
  ExpectTensorValues({6, 8}, a_plus_one_prop.value());

  const auto& a_plus_a_prop = properties.GetOutputProperties("a_plus_a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(a_plus_a_prop));
  ASSERT_TRUE(a_plus_a_prop.has_value());
  ExpectTensorValues({10, 14}, a_plus_a_prop.value());

  const auto& b_plus_2a_prop =
      properties.GetOutputProperties("b_plus_2a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(b_plus_2a_prop));
  ASSERT_TRUE(b_plus_2a_prop.has_value());
  ExpectTensorValues({18, 22}, b_plus_2a_prop.value());

  const auto& c_plus_b_plus_2a_prop =
      properties.GetOutputProperties("c_plus_b_plus_2a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(c_plus_b_plus_2a_prop));
  ASSERT_TRUE(c_plus_b_plus_2a_prop.has_value());
  ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
}
2097
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  // The _output_shape_vector annotation on Identity should be ignored without
  // aggressive shape inference and used with it.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({5, 7})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    // ASSERT before reading props[0].
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
2142
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  // The annotated shape [10,100] is compatible with the inferred [-1,100], so
  // aggressive inference should adopt the annotation.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 100})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // ASSERT before reading props[0].
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2170
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  // The annotated shape [10,10] conflicts with the inferred [-1,100], so even
  // aggressive inference should keep the inferred shape.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 10})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // ASSERT before reading props[0].
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
2198
TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
  // TestOpWithNoInferenceFn is, per its name, an op registered without a
  // shape inference function. The annotated output shape is expected to be
  // reported for it under aggressive inference.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(
      NodeDefBuilder("TestOpWithNoInferenceFn", "TestOpWithNoInferenceFn")
          .Attr("_same_output_for_iterations", true)
          .Attr("_output_shape_vector", {TensorShape({10, 100})})
          .Input("Input", 0, DT_FLOAT)
          .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
  // ASSERT before reading props[0].
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2225
TEST_F(GraphPropertiesTest, PartitionedCallOp) {
  // A PartitionedCall to an identity function should forward the shape of its
  // constant int32 input ([4]) to its output.
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;
  FunctionDef called_func = FunctionDefHelper::Create(
      "identity_function",
      /*in_def=*/{"arg0: int32"},
      /*out_def=*/{"ret0: int32"},
      /*attr_def=*/{},
      {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "Identity:output:0"}});
  *library.add_function() = called_func;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));

  Output in = ops::Const(root, {3, 1, 2, 0});
  NameAttrList b_name_attr;
  b_name_attr.set_name("identity_function");
  ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& out_props = properties.GetOutputProperties("identity_call");
  // ASSERT before reading out_props[0].
  ASSERT_EQ(1, out_props.size());
  EXPECT_EQ("int32: [4]", PropToString(out_props[0]));
}
2257
TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
  // A PartitionedCall to a function that adds its two partially known
  // [2,2,-1] inputs should report a [2,2,-1] output.
  auto f = FunctionDefHelper::Create(
      // Name
      "FunctionWhichAdds",
      // Inputs
      {"arg0: int32", "arg1: int32"},
      // Outputs
      {"ret0: int32"},
      /*attr_def=*/{},
      // Nodes
      {{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "a:z:0"}});

  FunctionDefLibrary function_lib;
  function_lib.add_function()->Swap(&f);
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib));

  PartialTensorShape input_shape({2, 2, -1});
  Output in1 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  Output in2 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  NameAttrList b_name_attr;
  b_name_attr.set_name("FunctionWhichAdds");
  ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& out_props = properties.GetOutputProperties("add_call");
  // ASSERT before reading out_props[0].
  ASSERT_EQ(1, out_props.size());
  EXPECT_EQ("int32: [2,2,-1]", PropToString(out_props[0]));
}
2298
// A function node whose output shape cannot be inferred statically, but which
// carries _output_shape_vector / _output_dtype_vector annotations. The
// annotation should only be used when aggressive_shape_inference is enabled.
TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
  // A function whose output shape cannot be inferred statically.
  auto f = FunctionDefHelper::Create(
      // Name
      "FuncShapeCannotBeInferred",
      // Inputs
      {},
      // Outputs
      {"output: float"},
      // Attrs
      {},
      // Nodes
      {
          // Placeholder without shape attr; unknown rank.
          {{"p"}, "Placeholder", {}, {{"dtype", DataType::DT_FLOAT}}},
      },
      // Returns
      {{"output", "p:output:0"}});
  FunctionDefLibrary function_lib;
  function_lib.add_function()->Swap(&f);
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib));
  tensorflow::Node* func_op;
  TensorShapeProto output_shape;
  output_shape.set_unknown_rank(false);
  output_shape.add_dim()->set_size(1);
  output_shape.add_dim()->set_size(2);
  output_shape.add_dim()->set_size(3);
  output_shape.add_dim()->set_size(4);
  // The function node, f, includes a shape annotation of [1,2,3,4].
  TF_ASSERT_OK(tensorflow::NodeBuilder("f", "FuncShapeCannotBeInferred",
                                       s.graph()->op_registry())
                   .Attr("_execution_count", 1)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_dtype_vector", {DataType::DT_FLOAT})
                   .Attr("_output_shape_vector", {output_shape})
                   .Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Without aggressive_shape_inference, InferStatically cannot infer the
  // output shape of the node f, so it is reported as unknown.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
    const auto& out_props = properties.GetOutputProperties("f");
    // f declares exactly one output ("output: float").
    ASSERT_EQ(1, out_props.size());
    EXPECT_EQ("float: ?", PropToString(out_props[0]));
  }
  // With aggressive_shape_inference, it skips recursively calling
  // InferStatically for the function node and outputs the annotated shape.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("f");
    ASSERT_EQ(1, out_props.size());
    EXPECT_EQ("float: [1,2,3,4]", PropToString(out_props[0]));
  }
}
2364
TEST_F(GraphPropertiesTest,
       SymbolicShapeInferenceWithReshapeOpsSharingShapeVector) {
  GrapplerItem item;
  // This graph creates a shape vector [-1, 10] from Concat(Const, Const),
  // used by two Reshape ops. One Reshape op is the segment_ids input to an
  // UnsortedSegmentSum op, which applies MergePrefix in its shape function.
  // segment_ids has shape [-1, 10] (from the Reshape), but with
  // SymbolicShapeRefiner, MergePrefix with the data input ([10, 10, 10, 10])
  // turns the -1 (unknown) dim into 10.
  // This dim value (10), however, should not affect the other Reshape op, even
  // though it shares the shape input: -1 in the shape input of a Reshape op is
  // a special marker for a computed output dim, not an unknown dim.
  // data and num_segments are inputs to UnsortedSegmentSum.

  TF_ASSERT_OK(NodeDefBuilder("data", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({10, 10, 10, 10}))
                   .Finalize(item.graph.add_node()));
  Tensor num_segments(DT_INT32, TensorShape({}));
  // Build the segment_ids input to UnsortedSegmentSum from Const ops, ConcatV2,
  // and Reshape ops. tensors_as_shape from the Const ops are propagated to the
  // ConcatV2 output to form the shape vector [-1, 10] fed to Reshape.
  test::FillIota<int>(&num_segments, 3);
  TF_ASSERT_OK(NodeDefBuilder("num_segments", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", num_segments)
                   .Finalize(item.graph.add_node()));
  Tensor minus_one(DT_INT32, TensorShape({1}));
  test::FillIota<int>(&minus_one, -1);
  TF_ASSERT_OK(NodeDefBuilder("minus_one", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", minus_one)
                   .Finalize(item.graph.add_node()));
  Tensor plus_ten(DT_INT32, TensorShape({1}));
  test::FillIota<int>(&plus_ten, 10);
  TF_ASSERT_OK(NodeDefBuilder("plus_ten", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", plus_ten)
                   .Finalize(item.graph.add_node()));
  Tensor axis(DT_INT32, TensorShape({}));
  test::FillIota<int>(&axis, -1);
  TF_ASSERT_OK(NodeDefBuilder("axis", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", axis)
                   .Finalize(item.graph.add_node()));
  std::vector<NodeDefBuilder::NodeOut> inputs(2);
  inputs[0] = NodeDefBuilder::NodeOut{"minus_one", 0, DT_INT32};
  inputs[1] = NodeDefBuilder::NodeOut{"plus_ten", 0, DT_INT32};
  TF_ASSERT_OK(NodeDefBuilder("concat", "ConcatV2")
                   .Input(inputs)
                   .Input("axis", 0, DT_INT32)
                   .Attr("N", 2)
                   .Attr("T", DT_INT32)
                   .Attr("Tidx", DT_INT32)
                   .Finalize(item.graph.add_node()));
  // NOTE(review): segment_ids_ is DT_FLOAT, so segment_ids (its Reshape) is
  // float, yet y below declares its segment_ids input as DT_INT32.
  // NodeDefBuilder does not cross-check producer dtypes, so the graph is not
  // type-consistent. The test only inspects shapes; confirm before tightening.
  TF_ASSERT_OK(NodeDefBuilder("segment_ids_", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("segment_ids_shape_before_reshape", "Shape")
                   .Input("segment_ids_", 0, DT_FLOAT)
                   .Attr("T", DT_FLOAT)
                   .Attr("out_type", DT_INT32)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("segment_ids", "Reshape")
                   .Input("segment_ids_", 0, DT_FLOAT)
                   .Input("concat", 0, DT_INT32)
                   .Attr("T", DT_FLOAT)
                   .Attr("Tshape", DT_INT32)
                   .Finalize(item.graph.add_node()));
  // The shape function of UnsortedSegmentSum applies MergePrefix to data and
  // segment_ids (the latter being the prefix). data has shape [10,10,10,10]
  // and segment_ids has shape [-1, 10]. MergePrefix and symbolic shape
  // inference assign 10 from the data shape to the unknown dim in segment_ids.
  TF_ASSERT_OK(NodeDefBuilder("y", "UnsortedSegmentSum")
                   .Input("data", 0, DT_FLOAT)
                   .Input("segment_ids", 0, DT_INT32)
                   .Input("num_segments", 0, DT_INT32)
                   .Attr("T", DT_FLOAT)
                   .Attr("Tindices", DT_INT32)
                   .Attr("Tnumsegments", DT_INT32)
                   .Finalize(item.graph.add_node()));
  // Note that y1=Reshape(x1) uses the same shape vector as segment_ids, but
  // y1's shape must not be affected by symbolic shape inference on
  // segment_ids.
  TF_ASSERT_OK(NodeDefBuilder("x1", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("y1", "Reshape")
                   .Input("x1", 0, DT_FLOAT)
                   .Input("concat", 0, DT_INT32)
                   .Attr("T", DT_FLOAT)
                   .Attr("Tshape", DT_INT32)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/true));
  const auto& y1_output_properties = properties.GetOutputProperties("y1");
  // y1=Reshape(x1), but x1's shape is unknown, so y1 should be [-1, 10].
  // The first dimension should not be 10.
  // ASSERT (not EXPECT) the size and rank: the checks below index into them.
  ASSERT_EQ(y1_output_properties.size(), 1);
  ASSERT_EQ(y1_output_properties[0].shape().dim_size(), 2);
  EXPECT_LT(y1_output_properties[0].shape().dim(0).size(), 0);
  EXPECT_EQ(y1_output_properties[0].shape().dim(1).size(), 10);
}
2468
// Exercises IsShapeFullyDefinedIntegerVectorOrScalar(ic, shape,
// tensor_as_shape, dtype). Each case varies a single condition: the dtype, the
// definedness of tensor_as_shape, the definedness of shape, and the rank of
// shape.
TEST(HelperFunctions, IsShapeFullyDefinedIntegerVectorOrScalar) {
  // Make a dummy InferenceContext.
  NodeDef node_def;
  OpRegistrationData op_reg_data;
  OpDefBuilder b("dummy");
  // Report a failure through gtest instead of CHECK-aborting the test binary.
  TF_ASSERT_OK(b.Finalize(&op_reg_data));
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types;
  InferenceContext ic(/*graph_def_version=*/0, node_def, op_reg_data.op_def,
                      /*input_shapes=*/{},
                      /*input_tensors=*/{},
                      /*input_tensors_as_shapes=*/{},
                      std::move(input_handle_shapes_and_types));

  // ShapeHandles for testing.
  ShapeHandle fully_defined_vector = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle vector_with_unknown = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.UnknownDim(), ic.MakeDim(7)});
  // INT64_MAX is used as unknown from Const. See the kUnknownFromConst const
  // in graph_properties.cc.
  ShapeHandle vector_with_unknown_from_const = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(INT64_MAX), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle rank_1_vector = ic.MakeShape({ic.MakeDim(4)});

  // Rank-1 shape and fully defined tensor_as_shape with INT32 or INT64.
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT32));
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT64));

  // Scalar (rank-0) shape with a fully defined tensor_as_shape.
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, ic.Scalar(), rank_1_vector, DT_INT32));

  // Non-integer data type.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_FLOAT));

  // tensor_as_shape including Unknown or UnknownFromConst, or either argument
  // being an unknown shape.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown_from_const, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, ic.UnknownShape(), DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, ic.UnknownShape(), fully_defined_vector, DT_INT32));

  // shape rank > 1. The first case changes only the rank, so it isolates the
  // rank check. The second also has an unknown-from-const tensor_as_shape.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, fully_defined_vector, fully_defined_vector, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, fully_defined_vector, vector_with_unknown_from_const, DT_INT32));
}
2518 } // namespace
2519 } // namespace grappler
2520 } // namespace tensorflow
2521