1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 #include "tensorflow/cc/framework/scope.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/graph_def_util.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/grappler/clusters/single_machine.h"
27 #include "tensorflow/core/grappler/grappler_item.h"
28 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
29 #include "tensorflow/core/grappler/inputs/utils.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/io/path.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace tensorflow {
37 namespace grappler {
38 namespace {
39
// Directory, relative to the TensorFlow source root, that holds the
// pre-generated *.pbtxt graphs loaded by the tests below.
constexpr char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata";
41
42 class GraphPropertiesTest : public ::testing::Test {
43 public:
SetUp()44 void SetUp() override {
45 // Provision a single machine with 3 cpu cores
46 cluster_.reset(new SingleMachine(5 * 60, 3, 0));
47 TF_CHECK_OK(cluster_->Provision());
48
49 // This function is simply
50 // out = Fill(shape, value), but
51 // Fill requires values in the shape input, not just shape of it, to infer
52 // output shape.
53 auto f = FunctionDefHelper::Create(
54 // Name
55 "MyFillFunc",
56 // Inputs
57 {"shape: int32", "value: float"},
58 // Outputs
59 {"out: float"},
60 // Attrs
61 {},
62 // Nodes
63 {
64 {{"a"},
65 "Fill",
66 {"shape", "value"},
67 {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
68 },
69 // Returns
70 {{"out", "a:output:0"}});
71 function_lib_.add_function()->Swap(&f);
72 }
73
TearDown()74 void TearDown() override {
75 TF_CHECK_OK(cluster_->Shutdown());
76 cluster_.reset();
77 }
78
79 protected:
80 // Returns a string form of <p>, suitable for comparing type and shape.
81 // Example output for 4-d float tensor: "float: [10,2,30,4]"
PropToString(const OpInfo::TensorProperties & p)82 string PropToString(const OpInfo::TensorProperties& p) {
83 string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
84 if (p.shape().unknown_rank()) {
85 strings::StrAppend(&s, "?");
86 } else {
87 strings::StrAppend(&s, "[");
88 for (int i = 0; i < p.shape().dim_size(); ++i) {
89 strings::StrAppend(&s, i == 0 ? "" : ",",
90 std::max<int64>(p.shape().dim(i).size(), -1));
91 }
92 strings::StrAppend(&s, "]");
93 }
94 return s;
95 }
96
97 // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
98 // ones.
ExpectTensorValues(const std::vector<int64> & expected,const TensorProto & tensor_proto_to_compare)99 void ExpectTensorValues(const std::vector<int64>& expected,
100 const TensorProto& tensor_proto_to_compare) {
101 Tensor tensor;
102 EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare));
103 EXPECT_EQ(expected.size(), tensor.NumElements());
104 // We're interested in only integer tensors as only shapes are exported as
105 // graph properties values.
106 CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
107 if (tensor.dtype() == DT_INT32) {
108 for (int i = 0; i < tensor.NumElements(); i++) {
109 EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
110 }
111 } else {
112 for (int i = 0; i < tensor.NumElements(); i++) {
113 EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
114 }
115 }
116 }
117
118 std::unique_ptr<SingleMachine> cluster_;
119 FunctionDefLibrary function_lib_;
120 };
121
// Static inference on the trivial test graph: RandomStandardNormal and AddN
// nodes must report a [10,1] float tensor, and each AddN must report an output
// identical to its input.
TEST_F(GraphPropertiesTest, StaticProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_CHECK_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The node has one input (the shape of the tensor to generate).
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output. ASSERT before indexing props[0].
      const auto props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, props.size()) << node.name();
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype());
      EXPECT_FALSE(prop.shape().unknown_rank());
      EXPECT_EQ(2, prop.shape().dim_size());
      EXPECT_EQ(10, prop.shape().dim(0).size());
      EXPECT_EQ(1, prop.shape().dim(1).size());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size()) << node.name();
      const OpInfo::TensorProperties& in_prop = in_props[0];
      EXPECT_EQ(DT_FLOAT, in_prop.dtype());
      EXPECT_FALSE(in_prop.shape().unknown_rank());
      EXPECT_EQ(2, in_prop.shape().dim_size());
      EXPECT_EQ(10, in_prop.shape().dim(0).size());
      EXPECT_EQ(1, in_prop.shape().dim(1).size());
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size()) << node.name();
      // Compare the text serializations so every field must match.
      string in_prop_str;
      ::tensorflow::protobuf::TextFormat::PrintToString(in_prop, &in_prop_str);
      string out_prop_str;
      ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
                                                        &out_prop_str);
      EXPECT_EQ(in_prop_str, out_prop_str);
    }
  }
}
165
// ClearOutputProperties / ClearInputProperties must leave the corresponding
// property list empty. Each branch first checks that there was something to
// clear, so the test cannot pass vacuously.
TEST_F(GraphPropertiesTest, ClearProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_CHECK_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // Verify the output properties exist before clearing them.
      EXPECT_EQ(1, properties.GetOutputProperties(node.name()).size());
      properties.ClearOutputProperties(node.name());
      const auto cleared_props = properties.GetOutputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      EXPECT_EQ(1, in_props.size());
      properties.ClearInputProperties(node.name());
      const auto cleared_props = properties.GetInputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    }
  }
}
192
// Dynamic inference (running the graph on the cluster) on the trivial test
// graph.
TEST_F(GraphPropertiesTest, DynamicProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_CHECK_OK(cluster_->Initialize(item));
  Status s = properties.InferDynamically(cluster_.get());
  TF_CHECK_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The random node is missing from the cost graph (why ?)
      EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
    } else if (node.op() == "AddN") {
      // Since the random node is missing, we can't infer the input properties
      // of the first AddN node. The other AddN nodes have the expected
      // properties.
      if (node.name() == "AddN") {
        const auto props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, props.size()) << node.name();
        const OpInfo::TensorProperties& prop = props[0];
        EXPECT_EQ(DT_INVALID, prop.dtype());
        EXPECT_TRUE(prop.shape().unknown_rank());
      } else {
        const auto props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, props.size()) << node.name();
        const OpInfo::TensorProperties& prop = props[0];
        EXPECT_EQ(DT_FLOAT, prop.dtype());
        EXPECT_FALSE(prop.shape().unknown_rank());
        EXPECT_EQ(2, prop.shape().dim_size());
        EXPECT_EQ(10, prop.shape().dim(0).size());
        EXPECT_EQ(1, prop.shape().dim(1).size());
        const auto out_props = properties.GetOutputProperties(node.name());
        ASSERT_EQ(1, out_props.size()) << node.name();
        // Input and output properties must serialize identically.
        string prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
        string out_prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
                                                          &out_prop_str);
        EXPECT_EQ(prop_str, out_prop_str);
      }
    }
  }
}
239
// A ref-typed Variable of shape [3,7] (initialized by an Assign from a Const)
// must be reported as float_ref [3,7] by both static and dynamic inference.
TEST_F(GraphPropertiesTest, Variables) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Var", "Variable")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", TensorShape({3, 7}))
                  .Finalize(item.graph.add_node()));
  item.fetch.push_back("Var");

  Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
  test::FillIota<float>(&initial_val, 0);
  TF_CHECK_OK(NodeDefBuilder("InitialVal", "Const")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("value", initial_val)
                  .Finalize(item.graph.add_node()));
  TF_CHECK_OK(NodeDefBuilder("InitVar", "Assign")
                  .Input("Var", 0, DT_FLOAT_REF)
                  .Input("InitialVal", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  item.init_ops.push_back("InitVar");

  {
    GraphProperties static_properties(item);
    TF_CHECK_OK(static_properties.InferStatically(false));

    const auto props = static_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    EXPECT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
  {
    TF_CHECK_OK(cluster_->Initialize(item));
    GraphProperties dynamic_properties(item);
    TF_CHECK_OK(dynamic_properties.InferDynamically(cluster_.get()));

    const auto props = dynamic_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    EXPECT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
}
288
// A resource handle passed through an Enter node must keep its handle shape
// information, so the downstream ReadVariableOp recovers float [3,7].
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", TensorShape({3, 7}))
                  .Finalize(item.graph.add_node()));
  TF_CHECK_OK(NodeDefBuilder("Enter", "Enter")
                  .Attr("T", DT_RESOURCE)
                  .Attr("frame_name", "while_context")
                  .Attr("is_constant", true)
                  .Attr("parallel_iterations", 10)
                  .Input("Var", 0, DT_RESOURCE)
                  .Finalize(item.graph.add_node()));
  TF_CHECK_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
                  .Attr("dtype", DT_FLOAT)
                  .Input("Enter", 0, DT_RESOURCE)
                  .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
318
// ReadVariableOp fed directly by a VarHandleOp must produce the variable's
// declared dtype and shape (float [3,7]).
TEST_F(GraphPropertiesTest, VarHandles) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", TensorShape({3, 7}))
                  .Finalize(item.graph.add_node()));

  TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                  .Attr("dtype", DT_FLOAT)
                  .Input("Var", 0, DT_RESOURCE)
                  .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto props = properties.GetOutputProperties("VarRead");
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
343
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
  // Test graph is first generated in python using:
  /*
    i0 = tf.constant(0)
    v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
    def cond(i, x):
      return i < 3
    def body(i, x):
      return i + 1, x + x
    v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
  */
  // and then modified by hand such that the ReadVariableOp is inside the loop
  // body instead of outside the while loop (which is the case when constructed
  // using the python API), such that we have the following pattern: VarHandleOp
  // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
  // DT_RESOURCE is passed all the way until ReadVariableOp.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop_var_handle_op.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> resource_nodes{
      "loop_var",       "while/Enter",         "while/Merge", "while/Switch",
      "while/Identity", "while/NextIteration", "while/Exit"};
  for (const string& node : resource_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Merge has 2 outputs; only the first one (the resource) is checked.
    ASSERT_GE(props.size(), 1) << node;
    EXPECT_EQ("resource: []", PropToString(props[0])) << node;
  }

  // After ReadVariableOp, the shape should be recovered.
  const auto props = properties.GetOutputProperties("while/ReadVariableOp");
  ASSERT_EQ(1, props.size());
  EXPECT_EQ("int32: []", PropToString(props[0]));
}
381
// A FIFO queue with no "shapes" attr and no enqueue ops: the dequeued tensor
// is expected to have a known dtype (float) but an unknown rank.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto fifo = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT});
  auto dequeue_op = ops::QueueDequeue(scope.WithOpName("Dequeue1"), fifo,
                                      {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_CHECK_OK(graph_props.InferStatically(false));

  const auto dequeued = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: ?", PropToString(dequeued[0]));
}
398
// A FIFO queue declared with a fully-defined "shapes" attr ([3,7,1]) and no
// enqueue ops: the dequeue output shape must come from that attr.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto fifo = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                             ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
  auto dequeue_op = ops::QueueDequeue(scope.WithOpName("Dequeue1"), fifo,
                                      {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_CHECK_OK(graph_props.InferStatically(false));

  const auto dequeued = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: [3,7,1]", PropToString(dequeued[0]));
}
416
// A FIFO queue declared with a partially-known "shapes" attr ([3,7,-1]) and no
// enqueue ops: the dequeue output keeps the unknown last dimension.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto fifo = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                             ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
  auto dequeue_op = ops::QueueDequeue(scope.WithOpName("Dequeue1"), fifo,
                                      {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_CHECK_OK(graph_props.InferStatically(false));

  const auto dequeued = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: [3,7,-1]", PropToString(dequeued[0]));
}
434
TEST_F(GraphPropertiesTest, Queues) {
  // Create a graph with known input shapes, and propagate the shapes through a
  // couple of queues.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
  Output rnd =
      ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
  Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
  auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
  auto dequeue1 =
      ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});

  auto q2 =
      ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
  Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
  auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
  auto dequeue2 =
      ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});

  // Queue4 is fed by two enqueue ops: one from square2, one from dequeue2.
  auto q4 =
      ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
  auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
  auto enqueue4_2 =
      ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
  auto dequeue4 =
      ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});

  // Create a queue that takes in three tensors.
  auto q5 = ops::RandomShuffleQueue(
      root.WithOpName("Queue5"),
      {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
  Output rnd2 =
      ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
  Output rnd3 =
      ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
  auto enqueue5 =
      ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
  auto dequeue5 = ops::QueueDequeue(
      root.WithOpName("Dequeue5"), q5,
      {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto props1 = properties.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, props1.size());
  EXPECT_EQ("float: [3,7]", PropToString(props1[0]));

  const auto props2 = properties.GetOutputProperties("Dequeue2");
  ASSERT_EQ(1, props2.size());
  EXPECT_EQ("float: [3,7]", PropToString(props2[0]));

  // Queue4 receives tensors from two producers (Square2 via Enqueue4 and
  // Dequeue2 via Enqueue4_2). Verify that the shapes coming from both enqueue
  // ops are merged properly to determine the shape of the data coming out of
  // the queue.
  const auto props4 = properties.GetOutputProperties("Dequeue4");
  ASSERT_EQ(1, props4.size());
  EXPECT_EQ("float: [3,7]", PropToString(props4[0]));

  // Each of the three Queue5 components (float [3,7], double [10],
  // float [1,2,3]) must be propagated to the matching Dequeue5 output.
  const auto props5 = properties.GetOutputProperties("Dequeue5");
  ASSERT_EQ(3, props5.size());
  EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
  EXPECT_EQ("double: [10]", PropToString(props5[1]));
  EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
}
505
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      x = tf.constant(2)
      y = tf.constant(5)
      z = tf.ones([1,1,1])
      def f1(): return tf.concat([z, z], axis=0)
      def f2(): return tf.concat([z, z], axis=1)
      r = tf.cond(tf.less(x, y), f1, f2)
      tf.concat([r, r], axis=2)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "merge_without_loops.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // The Merge joins [2,1,1] and [1,2,1], so its first two dims become unknown.
  std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
  std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
                                       "float: [1,2,1]"};
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto props = properties.GetOutputProperties(nodes[i]);
    // Merge has more than one output; only the first one is checked.
    ASSERT_FALSE(props.empty()) << nodes[i];
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << nodes[i];
    EXPECT_EQ(expected_outputs[i], PropToString(prop)) << nodes[i];
  }

  // The "Less" node should be fed by 2 int32 scalar constant values.
  const auto props = properties.GetInputProperties("Less");
  EXPECT_EQ(2, props.size());
  for (const OpInfo::TensorProperties& prop : props) {
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_TRUE(prop.has_value());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
}
547
TEST_F(GraphPropertiesTest, WhileLoop) {
  // Test graph produced in python using:
  /*
     with tf.Graph().as_default():
       i0 = tf.constant(0)
       m0 = tf.placeholder([-1, 2])
       c = lambda i, m: i < 10
       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
       r = tf.while_loop(
       c, b, loop_vars=[i0, m0],
       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
       with open('/tmp/graph.pbtxt', 'w') as f:
         f.write(str(tf.get_default_graph().as_graph_def()))
   */
  // NOTE(review): the snippet above does not create a node named "ones", which
  // this test reads at the end; the checked-in pbtxt may have been produced by
  // a slightly different script -- confirm against while_loop.pbtxt.

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};
  for (const string& node : nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Merge has more than one output; only the first one is checked.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,2]", PropToString(prop)) << node;
  }

  // The loop outputs batch dim should be different from the input batch dim
  // since we concatenated along the batch dim. Sizes <= -2 are symbolic
  // (unknown but distinguishable) dimensions.
  auto shape_in = properties.GetOutputProperties("ones").at(0).shape();
  auto shape_out = properties.GetOutputProperties("while/Exit_1").at(0).shape();
  EXPECT_GE(-2, shape_in.dim(0).size());
  EXPECT_GE(-2, shape_out.dim(0).size());
  EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
}
587
TEST_F(GraphPropertiesTest, NestedLoop) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=2)

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape([None, 1, None])])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        j, y = inner(0, x)
        return i+1, tf.concat([x, x], axis=0)

      r = tf.while_loop(outer_cond, outer_body,
                        loop_vars=[i0, tf.ones([1, 1, 1])],
                        shape_invariants=[i0.get_shape(),
                                          tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "nested_loop.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  // The outer loop only grows dim 0; the inner loop only grows dim 2.
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,1]", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)) << node;
  }
}
646
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q = tf.FIFOQueue(1, "float")

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=0)

        return tf.while_loop(inner_cond, inner_body,
                             loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape(None)])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        q.enqueue(x)
        y = tf.concat([x, x], axis=2)
        inner(0, q.dequeue())
        return i+1, y

      i, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[i0, tf.ones([1, 1, 1])],
                           shape_invariants=[i0.get_shape(),
                                             tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_queues.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  // The outer loop only grows dim 2; the inner loop, fed through the queue,
  // grows dim 0 on top of that.
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [1,1,-1]", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)) << node;
  }
}
709
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
        v = tf.get_variable(initializer=i0, name='loop_var')

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j + 1, y + y

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        y = x + x
        inner(0, v)
        return i + 1, y

      v, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[v, tf.constant(1)])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_resource_vars.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  // Both loops carry int32 scalars, which must stay int32 scalars.
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype()) << node;
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype()) << node;
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
  }
}
767
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q0 = tf.FIFOQueue(1, "float")
      q0.enqueue(tf.ones([2, 2]))
      q1 = tf.FIFOQueue(1, "float")

      def c(i, m):
        return i < 10

      def b(i, m):
        return i+1, tf.concat([m, m], axis=0)

      i, m = tf.while_loop(
          c, b, loop_vars=[i0,  q0.dequeue()],
          shape_invariants=[i0.get_shape(), tf.TensorShape(None)])

      q1.enqueue(m)
      v = q1.dequeue();
      tf.concat([v, v], axis=1)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "queues_and_loops.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};

  for (const string& node : nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,2]", PropToString(prop)) << node;
  }

  // The shape must survive the trip through q1 and the final concat doubles
  // dim 1.
  const auto props = properties.GetOutputProperties("concat");
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ("float: [-1,4]", PropToString(prop));
}
816
// Restore-style ops have no shape information of their own. Their output
// shape is expected to be inferred from the variable they are assigned to
// (here float [128,256]).
TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
                             DataType::DT_FLOAT);
  Output filename =
      ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
  Output tensor_name =
      ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
  Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
                                DataType::DT_FLOAT);
  Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);

  Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
                                      string("256 256 0,128:-"), TensorShape());
  Output restore_slice =
      ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
                        shape_and_slice, DataType::DT_FLOAT);
  Output init_restore_slice =
      ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);

  // NOTE(review): despite its name, "restore_v2" is built with RestoreSlice,
  // so the RestoreV2 op itself is not exercised by this test. TODO: confirm
  // whether ops::RestoreV2 was intended here.
  Output restore_v2 =
      ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
                        shape_and_slice, DataType::DT_FLOAT);
  Output init_restore_v2 =
      ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("init_restore");

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto restore_props = properties.GetOutputProperties("restore");
  ASSERT_EQ(1, restore_props.size());
  const OpInfo::TensorProperties& restore_prop = restore_props[0];
  EXPECT_EQ(DT_FLOAT, restore_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(restore_prop));

  const auto restore_slice_props =
      properties.GetOutputProperties("restore_slice");
  ASSERT_EQ(1, restore_slice_props.size());
  const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0];
  EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop));

  const auto restorev2_props = properties.GetOutputProperties("restore_v2");
  ASSERT_EQ(1, restorev2_props.size());
  const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0];
  EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop));

  // Check input shapes of assign op are propagated correctly.
  const auto input_props = properties.GetInputProperties("init_restore");
  ASSERT_EQ(2, input_props.size());
  const OpInfo::TensorProperties& input_prop = input_props[1];
  EXPECT_EQ(DT_FLOAT, input_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(input_prop));
}
873
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
  // One Restore output feeds two Assign ops: one into a Variable of unknown
  // shape and one into a [128,256] Variable. The Restore output is expected to
  // be inferred as float [128,256].
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
                             DataType::DT_FLOAT);
  Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}),
                              DataType::DT_FLOAT);
  Output filename =
      ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
  Output tensor_name =
      ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
  Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
                                DataType::DT_FLOAT);
  Output init = ops::Assign(s.WithOpName("init"), var, restore);
  Output init2 = ops::Assign(s.WithOpName("init2"), var2, restore);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("init");
  item.fetch.push_back("init2");

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // Guard the [0] access: a missing node would otherwise read out of bounds.
  const auto& props = properties.GetOutputProperties("restore");
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(prop));
}
902
TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
  // Const values must be recorded on the Const outputs and carried through
  // Identity ops, on both the Identity's input and output properties.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output a1 = ops::Identity(s.WithOpName("a1"), a);
  Output b = ops::Const(s.WithOpName("b"), 99, {});
  Output b1 = ops::Identity(s.WithOpName("b1"), b);
  Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
  Output c1 = ops::Identity(s.WithOpName("c1"), c);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // Every node here has one output and every Identity one input; assert that
  // up front so the [0] accesses below cannot read out of bounds.
  for (const char* node : {"a", "a1", "b", "b1", "c", "c1"}) {
    ASSERT_EQ(1, properties.GetOutputProperties(node).size()) << node;
  }
  for (const char* node : {"a1", "b1", "c1"}) {
    ASSERT_EQ(1, properties.GetInputProperties(node).size()) << node;
  }

  // Check output shapes.
  EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
  EXPECT_EQ("int32: [2]",
            PropToString(properties.GetOutputProperties("a1")[0]));
  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
  EXPECT_EQ("int32: [4,4,4]",
            PropToString(properties.GetOutputProperties("c")[0]));
  EXPECT_EQ("int32: [4,4,4]",
            PropToString(properties.GetOutputProperties("c1")[0]));

  // Check has_value.
  EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
  // The value is propagated for the rank-3 tensor c as well, not only for
  // scalars and 1-D vectors.
  EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());

  // Check values.
  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
  ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
  ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
  ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
  ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
  // c is a 4x4x4 tensor filled with 1.
  const std::vector<int64> c_values(4 * 4 * 4, 1);
  ExpectTensorValues(c_values, properties.GetOutputProperties("c")[0].value());
  ExpectTensorValues(c_values, properties.GetInputProperties("c1")[0].value());
  ExpectTensorValues(c_values,
                     properties.GetOutputProperties("c1")[0].value());
}
958
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 5, {2});
  Output b = ops::Identity(s.WithOpName("b"), a);
  Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only the shape of its dims input b but also b's value to
  // figure out the output shape; hence, Identity op (b) should pass a's value
  // as output_tensors_as_shape.
  Output d = ops::Fill(s.WithOpName("fill"), b, c);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
}
977
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
  // When using aggressive_shape_inference, we run EvaluateNode() for
  // whitelisted ops with small input / output tensors. For instance, a Fill op
  // is evaluated and produces an output tensor value if its output is small
  // (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
  // This avoids wasting time and memory producing huge tensors (e.g.,
  // initializing a large table using Fill).
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 4, {2});  // 4x4
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is small; expect output values of Fill op.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());
  }
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 1000, {4});  // 1000x1000x1000x1000
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is huge; in that case we skip value inference.
    // Otherwise, it'd be too much overhead.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }
}
1023
TEST_F(GraphPropertiesTest, PackWithConstInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {});
  Output b = ops::Const(s.WithOpName("b"), 2, {});
  Output c = ops::Const(s.WithOpName("c"), 3, {});
  Output d = ops::Const(s.WithOpName("d"), 4, {});
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  // Guard the [0] access so a missing node fails cleanly.
  const auto& out_props = properties.GetOutputProperties("fill");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
}
1046
TEST_F(GraphPropertiesTest, RankOp) {
  // Rank of a rank-3 Const must be inferred as the scalar value 3 and passed
  // through Identity.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("Const"), 1, {4, 4, 4});
  Output r = ops::Rank(s.WithOpName("Rank"), c);
  Output i = ops::Identity(s.WithOpName("Identity"), r);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto& rank_props = properties.GetOutputProperties("Rank");
  ASSERT_EQ(1, rank_props.size());
  const OpInfo::TensorProperties& rank_prop0 = rank_props[0];
  EXPECT_EQ("int32: []", PropToString(rank_prop0));
  EXPECT_TRUE(rank_prop0.has_value());
  ExpectTensorValues({3}, rank_prop0.value());

  const auto& identity_props = properties.GetOutputProperties("Identity");
  ASSERT_EQ(1, identity_props.size());
  const OpInfo::TensorProperties& identity_props0 = identity_props[0];
  EXPECT_EQ("int32: []", PropToString(identity_props0));
  EXPECT_TRUE(identity_props0.has_value());
  ExpectTensorValues({3}, identity_props0.value());
}
1068
TEST_F(GraphPropertiesTest, SizeOp) {
  // Size of a [1,2,3,4] Const must be inferred as the scalar value 24 and
  // passed through Identity.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("Const"), 1, {1, 2, 3, 4});
  Output r = ops::Size(s.WithOpName("Size"), c);
  Output i = ops::Identity(s.WithOpName("Identity"), r);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto& size_props = properties.GetOutputProperties("Size");
  ASSERT_EQ(1, size_props.size());
  const OpInfo::TensorProperties& size_props0 = size_props[0];
  EXPECT_EQ("int32: []", PropToString(size_props0));
  EXPECT_TRUE(size_props0.has_value());
  ExpectTensorValues({24}, size_props0.value());

  const auto& identity_props = properties.GetOutputProperties("Identity");
  ASSERT_EQ(1, identity_props.size());
  const OpInfo::TensorProperties& identity_props0 = identity_props[0];
  EXPECT_EQ("int32: []", PropToString(identity_props0));
  EXPECT_TRUE(identity_props0.has_value());
  ExpectTensorValues({24}, identity_props0.value());
}
1090
TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
  // ops of Const ops.
  // If output_tensors_as_shape is not set for those Identity ops, or the Pack
  // op doesn't take input_tensors_as_shape, the Fill op's input doesn't have a
  // value; hence, its output shape becomes unknown.
  Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
  Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
  Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
  Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
  Output a = ops::Identity(s.WithOpName("a"), a0);
  Output b = ops::Identity(s.WithOpName("b"), b0);
  Output c = ops::Identity(s.WithOpName("c"), c0);
  Output d = ops::Identity(s.WithOpName("d"), d0);
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
}
1122
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
  // Calls MyFillFunc (from the fixture's function_lib_) with a Const shape
  // {1, 2, 3, 4} and a scalar value; the function output is expected to be
  // float [1,2,3,4].
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
  Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                         s.graph()->op_registry());
  tensorflow::Node* func_op;
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto _value = tensorflow::ops::AsNodeOut(s, value);
  TF_CHECK_OK(
      builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyFillFunc");
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
}
1144
TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
  // Same as FunctionWithConstInput, but the function's shape input is an
  // Identity of a Const, so the value carried as tensors-as-shape (rather than
  // a Const tensor value) must be used as the Const input to the function.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
  Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
  Output shape = ops::Identity(s.WithOpName("shape"), shape_);
  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                         s.graph()->op_registry());
  tensorflow::Node* func_op;
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto _value = tensorflow::ops::AsNodeOut(s, value);
  TF_CHECK_OK(
      builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("MyFillFunc");
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
}
1170
TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32"},                                               // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));

  // MyFunc takes Const (shape) and passes it with Identity. Expect function
  // output has the same shape as well as value (output_tensors_as_shape) as
  // input Const tensor.
  Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op;
  TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/true));

  // MyFunc has exactly one input and one output; guard the [0] accesses.
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("int32: [2]", PropToString(out_prop0));
  EXPECT_TRUE(out_prop0.has_value());
  ExpectTensorValues({5, 7}, out_prop0.value());

  const auto& in_props = properties.GetInputProperties("MyFunc");
  ASSERT_EQ(1, in_props.size());
  ExpectTensorValues({5, 7}, in_props[0].value());
}
1206
TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
  FunctionDefLibrary library;
  // Function that adds two input values.
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32", "y: int32"},                                   // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:z:0"}});                                        // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));

  // The same Const {5, 7} is fed to both function inputs.
  Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op;
  TF_CHECK_OK(
      builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, the internal function does not
    // evaluate output value.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/false));
    const auto& out_props = properties.GetOutputProperties("MyFunc");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("int32: [2]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }

  {
    GraphProperties properties(item);
    // With aggressive_shape_inference, output value is evaluated.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/true));
    const auto& out_props = properties.GetOutputProperties("MyFunc");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("int32: [2]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());
    ExpectTensorValues({10, 14}, out_prop0.value());

    const auto& in_props = properties.GetInputProperties("MyFunc");
    ASSERT_EQ(2, in_props.size());
    ExpectTensorValues({5, 7}, in_props[0].value());
    ExpectTensorValues({5, 7}, in_props[1].value());
  }
}
1261
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
  // Create a graph with a function that takes a scalar value, so that a scalar
  // Placeholder is used as the input for the function's shape inference.
  // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
  // the input; all tensors are scalars.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float"},                                               // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
  Output placeholder =
      ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({})));
  Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
  auto _identity = tensorflow::ops::AsNodeOut(s, identity);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op;
  TF_CHECK_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // GraphDef producer versions <= 21 infer the output shape of a Placeholder
  // with an empty shape as unknown instead of scalar. The rest of this test is
  // meaningless under that legacy behavior, so make the check fatal.
  ASSERT_GT(item.graph.versions().producer(), 21);

  // MyFunc output shouldn't be unknown rank.
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/true));
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());
}
1301
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x,y)

    with tf.Graph().as_default():
      x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = MyAdd(x, y)
      z = MyAdd(x, z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "simple_function.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto& out_props = properties.GetOutputProperties("MyAdd_55e046a8");
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_FALSE(out_prop.shape().unknown_rank());
  // Fatal: dim(0) / dim(1) below must not be read if the rank is wrong.
  ASSERT_EQ(2, out_prop.shape().dim_size());
  EXPECT_EQ(1, out_prop.shape().dim(0).size());
  EXPECT_EQ(2, out_prop.shape().dim(1).size());

  // Fatal: both inputs are indexed below.
  const auto& in_props = properties.GetInputProperties("MyAdd_55e046a8");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1338
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
  // Loads large_function_graph.pbtxt and checks the inferred input and output
  // shapes of its function call node "y0".
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_function_graph.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // Fatal size checks: every element below is indexed unconditionally.
  const auto& out_props = properties.GetOutputProperties("y0");
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));

  const auto& in_props = properties.GetInputProperties("y0");
  ASSERT_EQ(4, in_props.size());

  const OpInfo::TensorProperties& in_prop0 = in_props[0];
  EXPECT_EQ("float: [64]", PropToString(in_prop0));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));

  const OpInfo::TensorProperties& in_prop2 = in_props[2];
  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));

  const OpInfo::TensorProperties& in_prop3 = in_props[3];
  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
1371
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
  // Test graph produced in python using:
  /*
    @function.Defun(noinline=True)
    def MyFunc():
      @function.Defun(*[tf.float32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[tf.float32] * 2)
      def Body(n, x):
        return n - 1, x + n

      i = tf.constant(10)
      return functional_ops.While([i, 0.], Cond, Body)

    with tf.Graph().as_default():
      z = MyFunc()
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_functional_while.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // Fatal: both outputs are indexed below.
  const auto& out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_INT32, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
  EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
1409
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
  // Loads function_error.pbtxt. Per the expectations below, InferStatically
  // still succeeds, the function call's output is reported with unknown rank,
  // and its inputs keep their [1,2] shapes. NOTE(review): the test data
  // presumably contains a function whose body fails shape inference — confirm
  // against the .pbtxt.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_error.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // Fatal size checks: elements are indexed unconditionally below.
  const auto& out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(1, out_props.size());

  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_TRUE(out_prop.shape().unknown_rank());

  const auto& in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1434
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  // Fatal: both inputs are indexed below.
  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1468
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_2.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto& out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  // Fatal: both inputs are indexed below.
  const auto& in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1501
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      c = tf.add(x, a)
      d = tf.add(y, b)
      return c

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_shapes.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto& out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  // Fatal: both inputs are indexed below.
  const auto& in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
1538
// Unknown placeholder dimensions are tracked symbolically. Each unknown dim is
// reported with a size <= -2. Tensors that must share a dimension report the
// same symbolic value, and unrelated dimensions report different values.
TEST_F(GraphPropertiesTest, SymbolicShapes) {
  // Build a simple graph with placeholders of unknown dimensions. These
  // dimensions will be encoded symbolically.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b =
      ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output c = ops::Identity(s.WithOpName("c"), a);
  Output d = ops::Identity(s.WithOpName("d"), b);
  Output e = ops::Add(s.WithOpName("e"), c, d);
  Output f = ops::Add(s.WithOpName("f"), a, c);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output g = ops::Shape(s.WithOpName("g"), c);
  Output h = ops::Fill(s.WithOpName("h"), g, zero);
  Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
  Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // Every rank check that precedes a dim(i) access is fatal (ASSERT_*), so a
  // wrong rank fails cleanly instead of indexing past the end of `dim`.

  // a and c = Identity(a) share both symbolic dims.
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(shape_a.dim_size(), shape_c.dim_size());
  EXPECT_GE(-2, shape_a.dim(0).size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
  EXPECT_GE(-2, shape_a.dim(1).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());

  PartialTensorShape shape(shape_a);
  EXPECT_FALSE(shape.IsFullyDefined());
  EXPECT_FALSE(shape.unknown_rank());

  // b and d = Identity(b) share one symbolic dim, distinct from a's first dim.
  const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
  const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
  ASSERT_EQ(1, shape_b.dim_size());
  ASSERT_EQ(shape_b.dim_size(), shape_d.dim_size());
  EXPECT_GE(-2, shape_b.dim(0).size());
  EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
  EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());

  // e = c + d broadcasts a rank-2 tensor with a rank-1 tensor.
  const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
  ASSERT_EQ(2, shape_e.dim_size());
  EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
  EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
  EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());

  // f = a + c: both operands have identical symbolic shapes.
  const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
  ASSERT_EQ(2, shape_f.dim_size());
  EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
  EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());

  // h = Fill(Shape(c), 0) must reproduce c's symbolic dims. The rank being
  // checked is h's own, because h's dims are indexed right after.
  const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
  ASSERT_EQ(2, shape_h.dim_size());
  EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
  EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());

  // j = Sum(a, axis 0) is rank 1 and keeps a's second dim.
  const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
  ASSERT_EQ(1, shape_j.dim_size());
  EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
1607
// InferStatically must succeed on a graph where a node carries a colocation
// constraint that points at a node which is no longer in the graph.
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output const_a = ops::Const(root.WithOpName("a"), 1.0f, {1});
  Output const_b = ops::Const(root.WithOpName("b"), 2.0f, {1});
  // Node "c" is colocated with node "a".
  Output const_c =
      ops::Const(root.WithOpName("c").ColocateWith(const_a), 3.0f, {1});
  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Simulate an optimization pass that removed node "a" while "c" still
  // references it. Late in graph execution this is acceptable: colocation was
  // validated earlier and device placement has already been done.
  GraphDef pruned_graph;
  for (const auto& node_def : item.graph.node()) {
    if (node_def.name() == "a") continue;
    *pruned_graph.add_node() = node_def;
  }
  item.graph.Swap(&pruned_graph);

  // Must return OK because colocation constraints are not validated here.
  GraphProperties properties(item);
  TF_EXPECT_OK(properties.InferStatically(false));
}
1631
// Filling a tensor using a shape obtained from ShapeN must reproduce the
// source tensor's shape exactly, including its symbolic (unknown) dims.
TEST_F(GraphPropertiesTest, ShapeTracking) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output matrix_in =
      ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output vector_in =
      ops::Placeholder(root.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output zero = ops::Const(root.WithOpName("zero"), 0.0f, {});
  auto both_shapes = ops::ShapeN(root.WithOpName("shapes"),
                                 {matrix_in, vector_in});
  Output filled_like_a = ops::Fill(root.WithOpName("o1"), both_shapes[0], zero);
  Output filled_like_b = ops::Fill(root.WithOpName("o2"), both_shapes[1], zero);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // o1 was filled from a's shape.
  const auto a_shape = properties.GetOutputProperties("a").at(0).shape();
  const auto o1_shape = properties.GetOutputProperties("o1").at(0).shape();
  EXPECT_EQ(a_shape.DebugString(), o1_shape.DebugString());

  // o2 was filled from b's shape.
  const auto b_shape = properties.GetOutputProperties("b").at(0).shape();
  const auto o2_shape = properties.GetOutputProperties("o2").at(0).shape();
  EXPECT_EQ(b_shape.DebugString(), o2_shape.DebugString());
}
1657
// Shape inference around fed nodes.
// - assume_valid_feeds=false: the fed node's output and its fanout get an
//   unknown rank.
// - assume_valid_feeds=true: shapes are propagated from the fed node's fanin.
TEST_F(GraphPropertiesTest, FedNodes) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  {
    // Conservative shape analysis: the shape of fed ports should be unknown
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/false));
    for (const auto& node : item.graph.node()) {
      if (node.op() == "Const") {
        continue;
      }
      // Fatal size checks: in_props[0] / out_props[0] are read right after,
      // and indexing an empty vector would be undefined behavior.
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      const OpInfo::TensorProperties& out_prop = out_props[0];

      if (node.name() == "x") {
        // x is fed: its input should have a known shape, while its output
        // doesn't
        EXPECT_FALSE(in_prop.shape().unknown_rank());
        EXPECT_EQ(1, in_prop.shape().dim_size());
        EXPECT_EQ(2, in_prop.shape().dim(0).size());
        EXPECT_TRUE(out_prop.shape().unknown_rank());
      } else if (node.op() == "Square" || node.op() == "AddN") {
        // These nodes are in the fanout of x: their shapes should be unknown.
        EXPECT_TRUE(in_prop.shape().unknown_rank());
        EXPECT_TRUE(out_prop.shape().unknown_rank());
      }
    }
  }
  {
    // Optimistic shape analysis: the shape of fed ports should be derived from
    // the shape of the fanin.
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/true));
    for (const auto& node : item.graph.node()) {
      if (node.op() == "Square" || node.op() == "AddN") {
        const auto in_props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, in_props.size());
        const OpInfo::TensorProperties& in_prop = in_props[0];
        EXPECT_EQ(DT_FLOAT, in_prop.dtype());
        EXPECT_FALSE(in_prop.shape().unknown_rank());
        EXPECT_EQ(2, in_prop.shape().dim_size());
        const auto out_props = properties.GetOutputProperties(node.name());
        ASSERT_EQ(1, out_props.size());
        const OpInfo::TensorProperties& out_prop = out_props[0];
        EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString());
      }
    }
  }
}
1716
// Loads a large graph with many nested loops and runs static shape inference
// on it. The intent is that inference is fast. This test only checks that it
// completes successfully; there is no explicit timing assertion here.
TEST_F(GraphPropertiesTest, Performance) {
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "large_graph.pbtxt.html");
  GrapplerItem item;
  TF_CHECK_OK(ReadGraphDefFromFile(graph_path, &item.graph));

  // Fill in default attr values, resolving ops against the global registry
  // plus the graph's own function library.
  const FunctionLibraryDefinition flib(OpRegistry::Global(),
                                       item.graph.library());
  TF_CHECK_OK(AddDefaultAttrsToGraphDef(&item.graph, flib,
                                        /*node_offset=*/0,
                                        /*skip_unknown_ops=*/true));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
}
1732
// Strided slices of a symbolic shape vector keep the symbolic dims: filling
// with shape[0:1] and shape[1:2] yields rank-1 tensors whose single dim is
// a's first and second dim respectively.
TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  auto shp = ops::Shape(s.WithOpName("shape"), {a});

  Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
  Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
  Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});

  // b = shp[0:1:1], c = shp[1:2:1].
  Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
  Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
  const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
  // Fatal rank checks: dim(0) / dim(1) are indexed right after, and reading a
  // protobuf repeated field past its size is invalid.
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(1, shape_o1.dim_size());
  ASSERT_EQ(1, shape_o2.dim_size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
}
1765
// The test slices the fully-defined shape vector of a [5,480,40,1] placeholder
// starting at index 0, with shrink_axis_mask=1. The resulting value (5) is
// only inferred when aggressive shape inference is enabled.
TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(root.WithOpName("input_placeholder"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
  auto input_shape = ops::Shape(root.WithOpName("input_shape"), input);

  Output begin = ops::Const(root.WithOpName("begin"), {0}, {1});
  Output end = ops::Const(root.WithOpName("end"), {3}, {1});
  Output stride = ops::Const(root.WithOpName("stride"), {1}, {1});

  Output slice = ops::StridedSlice(root.WithOpName("slice"), input_shape,
                                   begin, end, stride,
                                   ops::StridedSlice::ShrinkAxisMask(1));

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  {
    // Conservative inference: no value is materialized for the slice.
    GraphProperties conservative(item);
    TF_CHECK_OK(conservative.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false));
    const auto slice_prop = conservative.GetOutputProperties("slice").at(0);
    EXPECT_FALSE(slice_prop.has_value());
  }

  {
    // Aggressive inference: the slice's value is the placeholder's first dim.
    GraphProperties aggressive(item);
    TF_CHECK_OK(aggressive.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto slice_prop = aggressive.GetOutputProperties("slice").at(0);
    EXPECT_TRUE(slice_prop.has_value());
    ExpectTensorValues({5}, slice_prop.value());
  }
}
1807
// With aggressive shape inference enabled, constant values are propagated
// through element-wise Add / OnesLike chains on int32 vectors.
TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output b = ops::Const(s.WithOpName("b"), {8, 8}, {2});
  Output c = ops::Const(s.WithOpName("c"), {2, 2}, {2});

  Output a1 = ops::OnesLike(s.WithOpName("a1"), a);
  Output a_plus_one = ops::Add(s.WithOpName("a_plus_one"), a, a1);
  Output a_plus_a = ops::Add(s.WithOpName("a_plus_a"), a, a);
  Output b_plus_2a = ops::Add(s.WithOpName("b_plus_2a"), b, a_plus_a);
  Output c_plus_b_plus_2a =
      ops::Add(s.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));

  // Check output shapes and values. Each property is fetched with a
  // bounds-checked .at(0) and copied, so a missing node fails the test instead
  // of indexing an empty vector. has_value() is fatal because value() is read
  // right after.

  // a + ones_like(a) = [5,7] + [1,1].
  const auto a_plus_one_prop =
      properties.GetOutputProperties("a_plus_one").at(0);
  EXPECT_EQ("int32: [2]", PropToString(a_plus_one_prop));
  ASSERT_TRUE(a_plus_one_prop.has_value());
  ExpectTensorValues({6, 8}, a_plus_one_prop.value());

  // a + a = [5,7] + [5,7].
  const auto a_plus_a_prop = properties.GetOutputProperties("a_plus_a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(a_plus_a_prop));
  ASSERT_TRUE(a_plus_a_prop.has_value());
  ExpectTensorValues({10, 14}, a_plus_a_prop.value());

  // b + 2a = [8,8] + [10,14].
  const auto b_plus_2a_prop =
      properties.GetOutputProperties("b_plus_2a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(b_plus_2a_prop));
  ASSERT_TRUE(b_plus_2a_prop.has_value());
  ExpectTensorValues({18, 22}, b_plus_2a_prop.value());

  // c + b + 2a = [2,2] + [18,22].
  const auto c_plus_b_plus_2a_prop =
      properties.GetOutputProperties("c_plus_b_plus_2a").at(0);
  EXPECT_EQ("int32: [2]", PropToString(c_plus_b_plus_2a_prop));
  ASSERT_TRUE(c_plus_b_plus_2a_prop.has_value());
  ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
}
1850
// An Identity node annotated with an output shape of [5,7] (via the
// _same_output_for_iterations / _output_shape_vector attrs) is tested in two
// modes. The annotation is ignored by default and used only when aggressive
// shape inference is enabled.
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, -1}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({5, 7})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false));
    const auto props = properties.GetOutputProperties("Identity");
    // Fatal: props[0] is read right after.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto props = properties.GetOutputProperties("Identity");
    // Fatal: props[0] is read right after.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
1893
// An annotated output shape [10,100] that is compatible with the inferred
// [-1,100] is adopted under aggressive shape inference.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 100})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto props = properties.GetOutputProperties("Identity");
  // Fatal: props[0] is read right after.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
1920
// An annotated output shape [10,10] that conflicts with the inferred [-1,100]
// is ignored even under aggressive shape inference.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 10})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto props = properties.GetOutputProperties("Identity");
  // Fatal: props[0] is read right after.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
1947
1948 } // namespace
1949 } // namespace grappler
1950 } // namespace tensorflow
1951