1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/functional_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/graph_def_util.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
31 #include "tensorflow/core/grappler/inputs/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/test.h"
37 #ifdef INTEL_MKL
38 #include "tensorflow/core/graph/mkl_graph_util.h"
39 #endif
40
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44
45 using shape_inference::InferenceContext;
46 using shape_inference::ShapeAndType;
47 using shape_inference::ShapeHandle;
48
// Directory (relative to the TensorFlow source root) holding the .pbtxt test
// graphs loaded by the tests in this file.
constexpr char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata";
50
// Test-only op with a single float input and a single float output, registered
// deliberately WITHOUT a shape inference function (there is no SetShapeFn call
// in the chain below).
REGISTER_OP("TestOpWithNoInferenceFn")
    .Input("x: float")
    .Output("y: float")
    .Doc(R"doc(
Test op with no Inference Function registered.
x: input
y: output
)doc");
59
60 class GraphPropertiesTest : public ::testing::Test {
61 public:
SetUp()62 void SetUp() override {
63 // Provision a single machine with 3 cpu cores
64 cluster_.reset(new SingleMachine(5 * 60, 3, 0));
65 TF_ASSERT_OK(cluster_->Provision());
66
67 // This function is simply
68 // out = Fill(shape, value), but
69 // Fill requires values in the shape input, not just shape of it, to infer
70 // output shape.
71 auto f = FunctionDefHelper::Create(
72 // Name
73 "MyFillFunc",
74 // Inputs
75 {"shape: int32", "value: float"},
76 // Outputs
77 {"out: float"},
78 // Attrs
79 {},
80 // Nodes
81 {
82 {{"a"},
83 "Fill",
84 {"shape", "value"},
85 {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
86 },
87 // Returns
88 {{"out", "a:output:0"}});
89 function_lib_.add_function()->Swap(&f);
90 }
91
TearDown()92 void TearDown() override {
93 TF_ASSERT_OK(cluster_->Shutdown());
94 cluster_.reset();
95 }
96
97 protected:
98 // Returns a string form of <p>, suitable for comparing type and shape.
99 // Example output for 4-d float tensor: "float: [10,2,30,4]"
PropToString(const OpInfo::TensorProperties & p)100 string PropToString(const OpInfo::TensorProperties& p) {
101 string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
102 if (p.shape().unknown_rank()) {
103 strings::StrAppend(&s, "?");
104 } else {
105 strings::StrAppend(&s, "[");
106 for (int i = 0; i < p.shape().dim_size(); ++i) {
107 strings::StrAppend(&s, i == 0 ? "" : ",",
108 std::max<int64>(p.shape().dim(i).size(), -1));
109 }
110 strings::StrAppend(&s, "]");
111 }
112 return s;
113 }
114
115 // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
116 // ones.
ExpectTensorValues(const std::vector<int64> & expected,const TensorProto & tensor_proto_to_compare)117 void ExpectTensorValues(const std::vector<int64>& expected,
118 const TensorProto& tensor_proto_to_compare) {
119 Tensor tensor;
120 ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
121 EXPECT_EQ(expected.size(), tensor.NumElements());
122 // We're interested in only integer tensors as only shapes are exported as
123 // graph properties values.
124 ASSERT_TRUE(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
125 if (tensor.dtype() == DT_INT32) {
126 for (int i = 0; i < tensor.NumElements(); i++) {
127 EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
128 }
129 } else {
130 for (int i = 0; i < tensor.NumElements(); i++) {
131 EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
132 }
133 }
134 }
135
136 // Compare values of float (DT_FLOAT) tensor against expected
137 // ones.
ExpectFloatTensorValues(const std::vector<float> & expected,const TensorProto & tensor_proto_to_compare)138 void ExpectFloatTensorValues(const std::vector<float>& expected,
139 const TensorProto& tensor_proto_to_compare) {
140 Tensor tensor;
141 ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
142 EXPECT_EQ(expected.size(), tensor.NumElements());
143 ASSERT_EQ(tensor.dtype(), DT_FLOAT);
144 for (int i = 0; i < tensor.NumElements(); i++) {
145 EXPECT_EQ(expected[i], tensor.flat<float>()(i));
146 }
147 }
148
149 std::unique_ptr<SingleMachine> cluster_;
150 FunctionDefLibrary function_lib_;
151 };
152
// Static inference on the trivial test graph: RandomStandardNormal produces a
// [10, 1] float tensor, and every AddN passes that shape through unchanged.
TEST_F(GraphPropertiesTest, StaticProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_ASSERT_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The node has one input (the shape of the tensor to generate).
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output. ASSERT so props[0] is always valid.
      const auto props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, props.size());
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype());
      EXPECT_FALSE(prop.shape().unknown_rank());
      ASSERT_EQ(2, prop.shape().dim_size());
      EXPECT_EQ(10, prop.shape().dim(0).size());
      EXPECT_EQ(1, prop.shape().dim(1).size());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      EXPECT_EQ(DT_FLOAT, in_prop.dtype());
      EXPECT_FALSE(in_prop.shape().unknown_rank());
      ASSERT_EQ(2, in_prop.shape().dim_size());
      EXPECT_EQ(10, in_prop.shape().dim(0).size());
      EXPECT_EQ(1, in_prop.shape().dim(1).size());
      // AddN output matches its input in both type and shape.
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      EXPECT_EQ(in_prop.dtype(), out_props[0].dtype());
      EXPECT_EQ(in_prop.shape().DebugString(),
                out_props[0].shape().DebugString());
    }
  }
}
193
// ClearOutputProperties / ClearInputProperties drop the inferred properties of
// a node, after which the corresponding getter returns an empty list.
TEST_F(GraphPropertiesTest, ClearProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_ASSERT_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // Make sure there is something to clear; otherwise the emptiness check
      // below would pass vacuously.
      EXPECT_EQ(1, properties.GetOutputProperties(node.name()).size());
      properties.ClearOutputProperties(node.name());
      EXPECT_TRUE(properties.GetOutputProperties(node.name()).empty());
    } else if (node.op() == "AddN") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      properties.ClearInputProperties(node.name());
      EXPECT_TRUE(properties.GetInputProperties(node.name()).empty());
    }
  }
}
220
// Dynamic inference (running the graph on the cluster) on the trivial test
// graph. Properties come from the cost graph, so nodes missing from it report
// nothing useful.
TEST_F(GraphPropertiesTest, DynamicProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(cluster_->Initialize(item));
  Status s = properties.InferDynamically(cluster_.get());
  TF_ASSERT_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The random node is missing from the cost graph (why ?)
      EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
    } else if (node.op() == "AddN") {
      // Every AddN has exactly one input property. ASSERT so props[0] is valid.
      const auto props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, props.size());
      const OpInfo::TensorProperties& prop = props[0];
      // Since the random node is missing, we can't infer the input properties
      // of the first AddN node. The other AddN nodes have the expected
      // properties.
      if (node.name() == "AddN") {
        EXPECT_EQ(DT_INVALID, prop.dtype());
        EXPECT_TRUE(prop.shape().unknown_rank());
      } else {
        EXPECT_EQ(DT_FLOAT, prop.dtype());
        EXPECT_FALSE(prop.shape().unknown_rank());
        ASSERT_EQ(2, prop.shape().dim_size());
        EXPECT_EQ(10, prop.shape().dim(0).size());
        EXPECT_EQ(1, prop.shape().dim(1).size());
        const auto out_props = properties.GetOutputProperties(node.name());
#ifdef INTEL_MKL
        if (!NativeFormatEnabled()) {
          // Intel MKL AddN OP would have two output.
          // One is the real output, another one for MKL metadata
          ASSERT_EQ(2, out_props.size());
        } else {
          ASSERT_EQ(1, out_props.size());
        }
#else
        ASSERT_EQ(1, out_props.size());
#endif  // INTEL_MKL
        // The (first) output must be textually identical to the input.
        string prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
        string out_prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
                                                          &out_prop_str);
        EXPECT_EQ(prop_str, out_prop_str);
      }
    }
  }
}
277
278 // A test op that outputs different shape based on input_tensor in the shape
279 // inference context.
280 REGISTER_OP("DetectInputValueInShapeInferenceOp")
281 .Input("a: T")
282 .Output("o: T")
283 .Attr("T: {numbertype, bool}")
__anond3d7e4310202(shape_inference::InferenceContext* c) 284 .SetShapeFn([](shape_inference::InferenceContext* c) {
285 if (c->input_tensor(0)) {
286 // 10x10 if input_tensor is given to the inference context.
287 c->set_output(0, c->Matrix(10, 10));
288 return Status::OK();
289 }
290 // unknown rank if input_tensor is not provided.
291 return shape_inference::UnknownShape(c);
292 });
293
294 // Helper class for testing Const tensor skip.
295 class ConstTensorSkipTestCase {
296 public:
ConstTensorSkipTestCase(const DataType data_type,const std::vector<int64> shape,const double value,const bool expected)297 ConstTensorSkipTestCase(const DataType data_type,
298 const std::vector<int64> shape, const double value,
299 const bool expected)
300 : data_type_(data_type),
301 shape_(shape),
302 value_(value),
303 expected_(expected) {}
304
RunTestAndValidate() const305 void RunTestAndValidate() const {
306 LOG(INFO) << "Run Const tensor skip test: "
307 << "data_type: " << data_type_ << ", shape: {"
308 << absl::StrJoin(shape_, ",") << "}, value: " << value_
309 << ", expected: " << expected_;
310 // Build a graph wiht Const --> Identity --> Detect.
311 GrapplerItem item;
312 const gtl::ArraySlice<int64> shape_array_slice(shape_);
313 Tensor const_tensor_value(data_type_, TensorShape(shape_array_slice));
314 // Fille the const tensor value based on data type.
315 switch (data_type_) {
316 case DT_INT32:
317 test::FillIota<int32>(&const_tensor_value, static_cast<int32>(value_));
318 break;
319 case DT_INT64:
320 test::FillIota<int64>(&const_tensor_value, static_cast<int64>(value_));
321 break;
322 case DT_FLOAT:
323 test::FillIota<float>(&const_tensor_value, static_cast<float>(value_));
324 break;
325 case DT_DOUBLE:
326 test::FillIota<double>(&const_tensor_value,
327 static_cast<double>(value_));
328 break;
329 case DT_BFLOAT16:
330 test::FillIota<Eigen::bfloat16>(&const_tensor_value,
331 static_cast<Eigen::bfloat16>(value_));
332 break;
333 default:
334 CHECK(false) << "Unsupported data type (" << data_type_
335 << ") in this test.";
336 break;
337 }
338 TF_ASSERT_OK(NodeDefBuilder("const", "Const")
339 .Attr("dtype", data_type_)
340 .Attr("value", const_tensor_value)
341 .Finalize(item.graph.add_node()));
342 TF_ASSERT_OK(NodeDefBuilder("const_identity", "Identity")
343 .Attr("dtype", data_type_)
344 .Input("const", 0, data_type_)
345 .Finalize(item.graph.add_node()));
346 TF_ASSERT_OK(NodeDefBuilder("detect", "DetectInputValueInShapeInferenceOp")
347 .Attr("T", data_type_)
348 .Input("const_identity", 0, data_type_)
349 .Finalize(item.graph.add_node()));
350 item.fetch.push_back("const");
351 item.fetch.push_back("const_identity");
352 item.fetch.push_back("detect");
353
354 // Run static shape inference.
355 GraphProperties graph_properties(item);
356 TF_ASSERT_OK(graph_properties.InferStatically(false));
357
358 // Extract input / output properties of interest.
359 const auto& const_output = graph_properties.GetOutputProperties("const");
360 EXPECT_EQ(1, const_output.size());
361 const OpInfo::TensorProperties& const_output0 = const_output[0];
362 const auto& const_identity_input =
363 graph_properties.GetInputProperties("const_identity");
364 EXPECT_EQ(1, const_identity_input.size());
365 const OpInfo::TensorProperties& const_identity_input0 =
366 const_identity_input[0];
367 const auto& const_identity_output =
368 graph_properties.GetOutputProperties("const_identity");
369 EXPECT_EQ(1, const_identity_output.size());
370 const OpInfo::TensorProperties& const_identity_output0 =
371 const_identity_output[0];
372 EXPECT_TRUE(const_output0.has_value());
373 EXPECT_TRUE(const_identity_input0.has_value());
374 EXPECT_TRUE(const_identity_output0.has_value());
375 const auto& detect_input = graph_properties.GetInputProperties("detect");
376 EXPECT_EQ(1, detect_input.size());
377 const OpInfo::TensorProperties& detect_input0 = detect_input[0];
378 const auto& detect_output = graph_properties.GetOutputProperties("detect");
379 EXPECT_EQ(1, detect_output.size());
380 const OpInfo::TensorProperties& detect_output0 = detect_output[0];
381
382 // Tensor protos are propagated, regardless of types and sizes.
383 EXPECT_TRUE(const_output0.has_value());
384 EXPECT_TRUE(const_identity_input0.has_value());
385 EXPECT_TRUE(const_identity_output0.has_value());
386 EXPECT_TRUE(detect_input0.has_value());
387
388 // Detect op outputs 10x10 matrix if it has input_tensor in the shape
389 // inference context. Otherwise, unknown rank.
390 if (expected_) {
391 EXPECT_EQ(detect_output0.shape().dim_size(), 2);
392 EXPECT_EQ(detect_output0.shape().dim(0).size(), 10);
393 EXPECT_EQ(detect_output0.shape().dim(1).size(), 10);
394 } else {
395 EXPECT_TRUE(detect_output0.shape().unknown_rank());
396 }
397 }
398
399 private:
400 DataType data_type_;
401 std::vector<int64> shape_;
402 double value_;
403 bool expected_;
404 };
405
// Const tensor values are only propagated into shape inference while the
// tensor is small enough: in the cases below, 128 elements or fewer propagate
// and 129 or more do not, for every tested dtype.
TEST_F(GraphPropertiesTest, SkipInstantiatingConstTensor) {
  const ConstTensorSkipTestCase test_cases[] = {
      // data_type, shape, value, bool: propagate const?
      {DT_INT32, {16, 8}, 1, true},         // 128 elements; smaller than threshold
      {DT_INT32, {1, 129}, 2, false},       // 129 elements; larger than threshold
      {DT_INT64, {8, 8}, 3, true},          // 64 elements; smaller than threshold
      {DT_INT64, {128, 2}, 0, false},       // 256 elements; larger than threshold
      {DT_FLOAT, {16, 8}, 1.0, true},       // integer value for float tensor
      {DT_FLOAT, {16, 8}, 1.3, true},       // fractional value (1.3)
      {DT_FLOAT, {1, 129}, 0.7, false},     // fractional value (0.7)
      {DT_DOUBLE, {16, 8}, 1.0, true},      // integer value for double tensor
      {DT_DOUBLE, {16, 8}, 1.3, true},      // fractional value (1.3)
      {DT_DOUBLE, {1, 129}, 0.7, false},    // fractional value (0.7)
      {DT_BFLOAT16, {16, 8}, 1.0, true},    // integer value for bfloat16 tensor
      {DT_BFLOAT16, {16, 8}, 1.3, true},    // fractional value (1.3)
      {DT_BFLOAT16, {1, 129}, 0.7, false},  // fractional value (0.7)
  };
  for (const ConstTensorSkipTestCase& test_case : test_cases) {
    test_case.RunTestAndValidate();
  }
}
429
// A ref-typed Variable's declared shape must be reported by both static and
// dynamic inference.
TEST_F(GraphPropertiesTest, Variables) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "Variable")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  item.fetch.push_back("Var");

  Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
  test::FillIota<float>(&initial_val, 0);
  TF_ASSERT_OK(NodeDefBuilder("InitialVal", "Const")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", initial_val)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("InitVar", "Assign")
                   .Input("Var", 0, DT_FLOAT_REF)
                   .Input("InitialVal", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  item.init_ops.push_back("InitVar");

  // Shared check: "Var" has a single float-ref output of shape [3, 7].
  // Size checks are fatal so indexing below stays in bounds.
  auto expect_var_properties = [](GraphProperties* properties) {
    const auto props = properties->GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    ASSERT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  };

  {
    SCOPED_TRACE("static inference");
    GraphProperties static_properties(item);
    TF_ASSERT_OK(static_properties.InferStatically(false));
    expect_var_properties(&static_properties);
  }
  {
    SCOPED_TRACE("dynamic inference");
    TF_ASSERT_OK(cluster_->Initialize(item));
    GraphProperties dynamic_properties(item);
    TF_ASSERT_OK(dynamic_properties.InferDynamically(cluster_.get()));
    expect_var_properties(&dynamic_properties);
  }
}
478
// The shape attached to a VarHandleOp's resource must survive an Enter node and
// be recovered by a downstream ReadVariableOp.
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("Enter", "Enter")
                   .Attr("T", DT_RESOURCE)
                   .Attr("frame_name", "while_context")
                   .Attr("is_constant", true)
                   .Attr("parallel_iterations", 10)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Enter", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
  // Fatal checks so that props[0] and dim(i) below are always in bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
508
// ReadVariableOp directly on a VarHandleOp recovers the handle's declared
// dtype and shape.
TEST_F(GraphPropertiesTest, VarHandles) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));

  TF_ASSERT_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto props = properties.GetOutputProperties("VarRead");
  // Fatal checks so that props[0] and dim(i) below are always in bounds.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
533
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
  // Test graph is first generated in python using:
  /*
    i0 = tf.constant(0)
    v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
    def cond(i, x):
      return i < 3
    def body(i, x):
      return i + 1, x + x
    v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
  */
  // and then modified by hand such that the ReadVariableOp is inside the loop
  // body instead of outside the while loop (which is the case when constructed
  // using the python API), such that we have the following pattern: VarHandleOp
  // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
  // DT_RESOURCE is passed all the way until ReadVariableOp.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop_var_handle_op.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> resource_nodes{
      "loop_var",       "while/Enter",         "while/Merge", "while/Switch",
      "while/Identity", "while/NextIteration", "while/Exit"};
  for (const string& node : resource_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // At least one output (Merge has 2). Fatal so props[0] is valid.
    ASSERT_FALSE(props.empty()) << node;
    EXPECT_EQ("resource: []", PropToString(props[0])) << node;
  }

  // After ReadVariableOp, the shape should be recovered.
  const auto props = properties.GetOutputProperties("while/ReadVariableOp");
  ASSERT_EQ(1, props.size());
  EXPECT_EQ("int32: []", PropToString(props[0]));
}
571
// A queue with no `shapes` attr and no enqueue ops: only the dtype of the
// dequeued tensor is known, its rank is not.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue =
      ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT});
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_props = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_props.size());
  EXPECT_EQ("float: ?", PropToString(dequeue_props[0]));
}
588
// With no enqueue ops, a fully specified `shapes` attr on the queue alone
// determines the dequeued tensor's shape.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto queue_attrs = ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}});
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"),
                              {DataType::DT_FLOAT}, queue_attrs);
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_props = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_props.size());
  EXPECT_EQ("float: [3,7,1]", PropToString(dequeue_props[0]));
}
606
// With no enqueue ops, a partially specified `shapes` attr is reported as-is,
// the unknown dimension appearing as -1.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto queue_attrs = ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}});
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"),
                              {DataType::DT_FLOAT}, queue_attrs);
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphProperties graph_props(item);
  TF_ASSERT_OK(graph_props.InferStatically(false));

  const auto dequeue_props = graph_props.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_props.size());
  EXPECT_EQ("float: [3,7,-1]", PropToString(dequeue_props[0]));
}
624
TEST_F(GraphPropertiesTest, Queues) {
  // Create a graph with known input shapes, and propagate the shapes through a
  // couple of queues.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
  Output rnd =
      ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
  Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
  auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
  auto dequeue1 =
      ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});

  auto q2 =
      ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
  Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
  auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
  auto dequeue2 =
      ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});

  // Queue4 is fed by two enqueue ops: one from Square2 and one from Dequeue2.
  auto q4 =
      ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
  auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
  auto enqueue4_2 =
      ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
  auto dequeue4 =
      ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});

  // Create a queue that takes in three tensors.
  auto q5 = ops::RandomShuffleQueue(
      root.WithOpName("Queue5"),
      {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
  Output rnd2 =
      ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
  Output rnd3 =
      ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
  auto enqueue5 =
      ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
  auto dequeue5 = ops::QueueDequeue(
      root.WithOpName("Dequeue5"), q5,
      {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto props1 = properties.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, props1.size());
  EXPECT_EQ("float: [3,7]", PropToString(props1[0]));

  const auto props2 = properties.GetOutputProperties("Dequeue2");
  ASSERT_EQ(1, props2.size());
  EXPECT_EQ("float: [3,7]", PropToString(props2[0]));

  // Queue4 receives tensors from two different enqueue ops (Square2 and
  // Dequeue2). Verify that the shapes coming from both are merged properly to
  // determine the shape of the data coming out of the queue.
  const auto props4 = properties.GetOutputProperties("Dequeue4");
  ASSERT_EQ(1, props4.size());
  EXPECT_EQ("float: [3,7]", PropToString(props4[0]));

  // The dequeue5 op shape is known for each of its three components.
  const auto props5 = properties.GetOutputProperties("Dequeue5");
  ASSERT_EQ(3, props5.size());
  EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
  EXPECT_EQ("double: [10]", PropToString(props5[1]));
  EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
}
695
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      x = tf.constant(2)
      y = tf.constant(5)
      z = tf.ones([1,1,1])
      def f1(): return tf.concat([z, z], axis=0)
      def f2(): return tf.concat([z, z], axis=1)
      r = tf.cond(tf.less(x, y), f1, f2)
      tf.concat([r, r], axis=2)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "merge_without_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Merge combines the two cond branches, so only the dimension they agree on
  // stays known.
  std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
  std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
                                       "float: [1,2,1]"};
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto props = properties.GetOutputProperties(nodes[i]);
    // Fatal so that props[0] below is always in bounds.
    ASSERT_FALSE(props.empty()) << nodes[i];
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(expected_outputs[i], PropToString(prop));
  }

  // The "Less" node should be fed by 2 int32 scalar constant values.
  const auto props = properties.GetInputProperties("Less");
  ASSERT_EQ(2, props.size());
  for (const OpInfo::TensorProperties& prop : props) {
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_TRUE(prop.has_value());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
}
737
TEST_F(GraphPropertiesTest, WhileLoop) {
  // Test graph produced in python using:
  /*
     with tf.Graph().as_default():
       i0 = tf.constant(0)
       m0 = tf.placeholder([-1, 2])
       c = lambda i, m: i < 10
       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
       r = tf.while_loop(
              c, b, loop_vars=[i0, m0],
              shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
       with open('/tmp/graph.pbtxt', 'w') as f:
         f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};
  for (const string& node : nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal so that props[0] below is always in bounds.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,2]", PropToString(prop));
  }

  // The loop outputs batch dim should be different from the input batch dim
  // since we concatenated along the batch dim. Unknown dims are represented by
  // values <= -2 here, so distinct unknowns can be told apart.
  auto shape_in = properties.GetOutputProperties("ones").at(0).shape();
  auto shape_out = properties.GetOutputProperties("while/Exit_1").at(0).shape();
  ASSERT_GT(shape_in.dim_size(), 0);
  ASSERT_GT(shape_out.dim_size(), 0);
  EXPECT_GE(-2, shape_in.dim(0).size());
  EXPECT_GE(-2, shape_out.dim(0).size());
  EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
}
777
TEST_F(GraphPropertiesTest, NestedLoop) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=2)

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape([None, 1, None])])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        j, y = inner(0, x)
        return i+1, tf.concat([x, x], axis=0)

      r = tf.while_loop(outer_cond, outer_body,
                        loop_vars=[i0, tf.ones([1, 1, 1])],
                        shape_invariants=[i0.get_shape(),
                                          tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "nested_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // The outer loop concatenates along axis 0 only; the inner loop along
  // axis 2 only.
  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal so that props[0] below is always in bounds.
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,1,1]", PropToString(prop));
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop));
  }
}
836
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
  // The test graph was generated in Python with:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q = tf.FIFOQueue(1, "float")

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=0)

        return tf.while_loop(inner_cond, inner_body,
                             loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape(None)])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        q.enqueue(x)
        y = tf.concat([x, x], axis=2)
        inner(0, q.dequeue())
        return i+1, y

      i, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[i0, tf.ones([1, 1, 1])],
                           shape_invariants=[i0.get_shape(),
                                             tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */
  GrapplerItem item;
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "loops_and_queues.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // (node name, expected rendered float property) for output 0. The outer
  // loop grows its value along axis 2; the inner loop (fed through the queue)
  // grows along axis 0, leaving dims 0 and 2 unknown.
  const std::vector<std::pair<string, string>> expected_outputs = {
      {"while/Merge_1", "float: [1,1,-1]"},
      {"while/NextIteration_1", "float: [1,1,-1]"},
      {"while/Exit_1", "float: [1,1,-1]"},
      {"while/while/Merge_1", "float: [-1,1,-1]"},
      {"while/while/NextIteration_1", "float: [-1,1,-1]"},
      {"while/while/Exit_1", "float: [-1,1,-1]"},
  };
  for (const auto& entry : expected_outputs) {
    const auto& node_props = properties.GetOutputProperties(entry.first);
    const OpInfo::TensorProperties& output0 = node_props[0];
    EXPECT_EQ(DT_FLOAT, output0.dtype());
    EXPECT_EQ(entry.second, PropToString(output0));
  }
}
899
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
  // The test graph was generated in Python with:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
        v = tf.get_variable(initializer=i0, name='loop_var')

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j + 1, y + y

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        y = x + x
        inner(0, v)
        return i + 1, y

      v, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[v, tf.constant(1)])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */
  GrapplerItem item;
  TF_ASSERT_OK(ReadGraphDefFromFile(
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                   "loops_and_resource_vars.pbtxt"),
      &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Both loops carry int32 scalars, so every checked node should report a
  // scalar int32 as its first output.
  const std::vector<string> loop_nodes = {
      // Outer loop.
      "while/Merge_1", "while/NextIteration_1", "while/Exit_1",
      // Inner loop.
      "while/while/Merge_1", "while/while/NextIteration_1",
      "while/while/Exit_1"};
  for (size_t idx = 0; idx < loop_nodes.size(); ++idx) {
    const auto& node_props = properties.GetOutputProperties(loop_nodes[idx]);
    const OpInfo::TensorProperties& first_output = node_props[0];
    EXPECT_EQ(DT_INT32, first_output.dtype());
    EXPECT_EQ("int32: []", PropToString(first_output));
  }
}
957
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
  // The test graph was generated in Python with:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q0 = tf.FIFOQueue(1, "float")
      q0.enqueue(tf.ones([2, 2]))
      q1 = tf.FIFOQueue(1, "float")

      def c(i, m):
        return i < 10

      def b(i, m):
        return i+1, tf.concat([m, m], axis=0)

      i, m = tf.while_loop(
          c, b, loop_vars=[i0, q0.dequeue()],
          shape_invariants=[i0.get_shape(), tf.TensorShape(None)])

      q1.enqueue(m)
      v = q1.dequeue()
      tf.concat([v, v], axis=1)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */
  GrapplerItem item;
  TF_ASSERT_OK(ReadGraphDefFromFile(
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                   "queues_and_loops.pbtxt"),
      &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // A [2, 2] tensor enters the loop through q0 and is doubled along axis 0
  // each iteration, so dim 0 becomes unknown while dim 1 stays 2.
  for (const char* loop_node :
       {"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}) {
    const auto& loop_props = properties.GetOutputProperties(loop_node);
    const OpInfo::TensorProperties& loop_prop = loop_props[0];
    EXPECT_EQ(DT_FLOAT, loop_prop.dtype());
    EXPECT_EQ("float: [-1,2]", PropToString(loop_prop));
  }

  // The loop result goes through q1 and is concatenated with itself along
  // axis 1, giving [-1, 4].
  const auto& concat_props = properties.GetOutputProperties("concat");
  const OpInfo::TensorProperties& concat_prop = concat_props[0];
  EXPECT_EQ(DT_FLOAT, concat_prop.dtype());
  EXPECT_EQ("float: [-1,4]", PropToString(concat_prop));
}
1006
TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // A [128, 256] float variable is assigned from three restore nodes; each
  // restore output is expected to be inferred as float [128, 256].
  Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
                             DataType::DT_FLOAT);
  Output filename =
      ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
  Output tensor_name =
      ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());

  // 1) Plain Restore.
  Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
                                DataType::DT_FLOAT);
  Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);

  // 2) RestoreSlice driven by a constant shape-and-slice spec.
  Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
                                      string("256 256 0,128:-"), TensorShape());
  Output restore_slice =
      ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
                        shape_and_slice, DataType::DT_FLOAT);
  Output init_restore_slice =
      ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);

  // 3) NOTE(review): despite its name, "restore_v2" is built with
  //    ops::RestoreSlice rather than ops::RestoreV2, so RestoreV2 shape
  //    inference is not exercised here. TODO confirm whether RestoreV2 was
  //    intended.
  Output restore_v2 =
      ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
                        shape_and_slice, DataType::DT_FLOAT);
  Output init_restore_v2 =
      ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("init_restore");

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  for (const char* restore_node : {"restore", "restore_slice", "restore_v2"}) {
    const auto& restore_outputs = properties.GetOutputProperties(restore_node);
    const OpInfo::TensorProperties& restored = restore_outputs[0];
    EXPECT_EQ(DT_FLOAT, restored.dtype());
    EXPECT_EQ("float: [128,256]", PropToString(restored));
  }

  // The value input (index 1) of the Assign should carry the same shape.
  const auto& assign_inputs = properties.GetInputProperties("init_restore");
  ASSERT_EQ(2, assign_inputs.size());
  const OpInfo::TensorProperties& assigned_value = assign_inputs[1];
  EXPECT_EQ(DT_FLOAT, assigned_value.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(assigned_value));
}
1063
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // One Restore feeds two Assigns: "init" targets a variable of unknown shape,
  // "init2" a [128, 256] variable. The restore output is expected to pick up
  // the known [128, 256] shape.
  Output unknown_var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
                                     DataType::DT_FLOAT);
  Output known_var = ops::Variable(
      s.WithOpName("var2"), TensorShape({128, 256}), DataType::DT_FLOAT);
  Output filename =
      ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
  Output tensor_name =
      ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
  Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
                                DataType::DT_FLOAT);
  Output init = ops::Assign(s.WithOpName("init"), unknown_var, restore);
  Output init2 = ops::Assign(s.WithOpName("init2"), known_var, restore);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"init", "init2"};

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto& restore_outputs = properties.GetOutputProperties("restore");
  const OpInfo::TensorProperties& restored = restore_outputs[0];
  EXPECT_EQ(DT_FLOAT, restored.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(restored));
}
1092
TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Three int32 constants of different ranks, each forwarded by an Identity:
  //   a: 1-D vector {5, 7}, b: scalar 99, c: rank-3 4x4x4 tensor of 1s.
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output a1 = ops::Identity(s.WithOpName("a1"), a);
  Output b = ops::Const(s.WithOpName("b"), 99, {});
  Output b1 = ops::Identity(s.WithOpName("b1"), b);
  Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
  Output c1 = ops::Identity(s.WithOpName("c1"), c);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Check output shapes.
  EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
  EXPECT_EQ("int32: [2]",
            PropToString(properties.GetOutputProperties("a1")[0]));
  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
  EXPECT_EQ("int32: [4,4,4]",
            PropToString(properties.GetOutputProperties("c")[0]));
  EXPECT_EQ("int32: [4,4,4]",
            PropToString(properties.GetOutputProperties("c1")[0]));

  // Check has_value. A value is expected on both sides of every Identity,
  // for the vector, the scalar, and the rank-3 tensor alike.
  EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
  EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
  EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());

  // Check values.
  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
  ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
  ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
  ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
  ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
  // c (and therefore c1) holds 4 * 4 * 4 ones.
  const std::vector<int64> c_values(4 * 4 * 4, 1);
  ExpectTensorValues(c_values, properties.GetOutputProperties("c")[0].value());
  ExpectTensorValues(c_values, properties.GetInputProperties("c1")[0].value());
  ExpectTensorValues(c_values,
                     properties.GetOutputProperties("c1")[0].value());
}
1148
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // a is the int32 vector {5, 5}; b forwards it through an Identity.
  Output a = ops::Const(s.WithOpName("a"), 5, {2});
  Output b = ops::Identity(s.WithOpName("b"), a);
  Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only b's shape but also b's value to figure out its output
  // shape; hence, the Identity op (b) must pass a's value through as
  // output_tensors_as_shape.
  Output d = ops::Fill(s.WithOpName("fill"), b, c);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
}
1167
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
  // With aggressive_shape_inference, EvaluateNode() runs only for allowlisted
  // ops with small input/output tensors. Fill, for instance, is evaluated to
  // produce an output value only when its output is small (currently fewer
  // than 17 elements); otherwise EvaluateNode() is skipped so that we don't
  // waste time and memory materializing huge tensors (e.g., a large table
  // initialized with Fill).
  // Float constants with fractional values are not propagated even when they
  // are small, so this test uses integer values.
  {
    // Small output: Fill to [4, 4] (16 elements), so a value is expected.
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output dims = ops::Const(scope.WithOpName("a"), 4, {2});  // 4x4
    Output fill_value = ops::Const(scope.WithOpName("const"), 7, {});
    Output fill = ops::Fill(scope.WithOpName("fill"), dims, fill_value);

    GrapplerItem item;
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& fill_props = properties.GetOutputProperties("fill");
    const OpInfo::TensorProperties& fill_prop = fill_props[0];
    EXPECT_EQ("int32: [4,4]", PropToString(fill_prop));
    EXPECT_TRUE(fill_prop.has_value());
  }
  {
    // Huge output: Fill to [1000, 1000, 1000, 1000]; value inference is
    // expected to be skipped to avoid the overhead.
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output dims = ops::Const(scope.WithOpName("a"), 1000, {4});
    Output fill_value = ops::Const(scope.WithOpName("const"), 7, {});
    Output fill = ops::Fill(scope.WithOpName("fill"), dims, fill_value);

    GrapplerItem item;
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& fill_props = properties.GetOutputProperties("fill");
    const OpInfo::TensorProperties& fill_prop = fill_props[0];
    EXPECT_EQ("int32: [1000,1000,1000,1000]", PropToString(fill_prop));
    EXPECT_FALSE(fill_prop.has_value());
  }
}
1217
TEST_F(GraphPropertiesTest, PackWithConstInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Pack the scalar constants a=1, b=2, c=3, d=4 into the 1-D tensor
  // {1, 2, 3, 4}. (ops::Stack creates a Pack node.)
  const std::vector<string> scalar_names = {"a", "b", "c", "d"};
  std::vector<Output> elements;
  for (int k = 0; k < 4; ++k) {
    elements.push_back(ops::Const(s.WithOpName(scalar_names[k]), k + 1, {}));
  }
  Output packed = ops::Stack(s.WithOpName("pack"), elements);
  Output fill_value = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill uses both the shape and the value of `packed` to determine its
  // output shape, so the constant values must flow through Pack.
  Output fill = ops::Fill(s.WithOpName("fill"), packed, fill_value);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& fill_props = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& fill_prop = fill_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(fill_prop));
}
1240
TEST_F(GraphPropertiesTest, RankOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // The rank of a 4x4x4 constant is 3; Identity forwards that scalar.
  Output cube = ops::Const(s.WithOpName("Const"), 1, {4, 4, 4});
  Output rank = ops::Rank(s.WithOpName("Rank"), cube);
  Output forwarded = ops::Identity(s.WithOpName("Identity"), rank);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Both nodes should expose the rank as a known int32 scalar value.
  for (const char* node_name : {"Rank", "Identity"}) {
    const auto& outputs = properties.GetOutputProperties(node_name);
    const OpInfo::TensorProperties& scalar = outputs[0];
    EXPECT_EQ("int32: []", PropToString(scalar));
    EXPECT_TRUE(scalar.has_value());
    ExpectTensorValues({3}, scalar.value());
  }
}
1262
TEST_F(GraphPropertiesTest, SizeOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // A [1, 2, 3, 4] constant has 1 * 2 * 3 * 4 = 24 elements.
  Output tensor = ops::Const(s.WithOpName("Const"), 1, {1, 2, 3, 4});
  Output size = ops::Size(s.WithOpName("Size"), tensor);
  Output forwarded = ops::Identity(s.WithOpName("Identity"), size);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Verifies that `node_name` outputs an int32 scalar with known value 24.
  auto check_size_value = [&](const string& node_name) {
    const auto& node_props = properties.GetOutputProperties(node_name);
    const OpInfo::TensorProperties& scalar = node_props[0];
    EXPECT_EQ("int32: []", PropToString(scalar));
    EXPECT_TRUE(scalar.has_value());
    ExpectTensorValues({24}, scalar.value());
  };
  check_size_value("Size");
  check_size_value("Identity");
}
1284
TEST_F(GraphPropertiesTest, PackWithConstMinus1AndReshapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape0 = ops::Const(s.WithOpName("shape0"), 4, {});
  Output shape1 = ops::Const(s.WithOpName("shape1"), -1, {});
  Output pack = ops::Stack(s.WithOpName("pack"), {shape0, shape1});
  // pack is [2], with values {4, -1}.

  Output x0_ = ops::Placeholder(s.WithOpName("x0_"), DataType::DT_FLOAT);
  Output x1_ = ops::Placeholder(s.WithOpName("x1_"), DataType::DT_FLOAT);

  Output x0 = ops::Reshape(s.WithOpName("x0"), x0_, pack);
  Output x1 = ops::Reshape(s.WithOpName("x1"), x1_, pack);
  // Two unknown-rank tensors (x0_ and x1_) are reshaped with pack {4, -1}, so
  // both outputs have shape [4, -1]. However, although both Reshape ops use
  // the same shape input, their output shapes can differ: the unknown dims
  // (-1) of x0 and x1 are not necessarily the same.

  // Condition inputs to the Select ops. Note that s0 has a fully defined
  // shape, while s1 has an unknown shape.
  Output s0 = ops::Const(s.WithOpName("s0"), true, {4, 16});
  Output s1 = ops::Placeholder(s.WithOpName("s1"), DataType::DT_BOOL);

  Output y0 = ops::Placeholder(s.WithOpName("y0"), DataType::DT_FLOAT);
  Output y1 = ops::Placeholder(s.WithOpName("y1"), DataType::DT_FLOAT);

  // We instantiate SelectV2, but will replace it with Select below. The shape
  // inference function for Select links the shapes of its inputs and output,
  // as they should have the same shapes.
  Output z0 = ops::SelectV2(s.WithOpName("z0"), s0, x0, y0);
  Output z1 = ops::SelectV2(s.WithOpName("z1"), s1, x1, y1);

  // For z0, since the shape of s0 is known, the symbolic shape manager in
  // shape inference makes the shapes of x0, y0, and z0 equal to the shape of
  // s0, which is [4, 16].
  // For z1, s1 and y1 both have unknown shapes, so the best we can infer for
  // x1, y1, and z1 is [4, -1].
  // Note that x0 and x1 share the same shape input to their Reshape ops, but
  // the -1 in that input must not be treated as one shared symbolic unknown
  // dim; it is merely the constant value Reshape uses to identify an unknown
  // dim.

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Replace SelectV2 op with Select op.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    auto* node = item.graph.mutable_node(i);
    if (node->op() == "SelectV2") {
      node->set_op("Select");
    }
  }

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  for (const auto& node_name : {"x0", "y0", "z0"}) {
    const auto& out_props = properties.GetOutputProperties(node_name);
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [4,16]", PropToString(out_prop0));
  }
  {
    const auto& out_props = properties.GetOutputProperties("s0");
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("bool: [4,16]", PropToString(out_prop0));
  }

  for (const auto& node_name : {"x1", "y1", "z1"}) {
    const auto& out_props = properties.GetOutputProperties(node_name);
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [4,-1]", PropToString(out_prop0));
  }
  // The condition input of Select may be either a vector or have the same
  // shape as the other inputs/output. So even though the inputs and output
  // are known to be [4, ?], we can't tell whether s1 is [4, ?] or a vector;
  // hence, its shape should remain unknown.
  {
    const auto& out_props = properties.GetOutputProperties("s1");
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("bool: ?", PropToString(out_prop0));
  }
}
1366
TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Like PackWithConstInput, except each packed scalar is an Identity of a
  // Const (a0 -> a, b0 -> b, ...). Unless the Identity ops set
  // output_tensors_as_shape and Pack consumes input_tensors_as_shape, Fill's
  // shape input has no known value and its output shape becomes unknown.
  const std::vector<std::pair<string, int>> scalar_specs = {
      {"a", 1}, {"b", 2}, {"c", 3}, {"d", 4}};
  std::vector<Output> consts;
  for (const auto& spec : scalar_specs) {
    consts.push_back(
        ops::Const(s.WithOpName(spec.first + "0"), spec.second, {}));
  }
  std::vector<Output> identities;
  for (size_t k = 0; k < consts.size(); ++k) {
    identities.push_back(
        ops::Identity(s.WithOpName(scalar_specs[k].first), consts[k]));
  }
  // ops::Stack creates a Pack node; its output is the 1-D tensor {1, 2, 3, 4}.
  Output packed = ops::Stack(s.WithOpName("pack"), identities);
  Output fill_value = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs both the shape and the value of `packed` to determine its
  // output shape.
  Output fill = ops::Fill(s.WithOpName("fill"), packed, fill_value);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& fill_props = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& fill_prop = fill_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(fill_prop));
}
1398
TEST_F(GraphPropertiesTest, FunctionWithDtResourceInput) {
  // Function ops may take DT_RESOURCE inputs. Unless shapes and dtypes are
  // properly propagated through the DT_RESOURCE _Arg, the output shapes of
  // such function ops cannot be inferred.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_with_dt_resource_input.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));

  // This graph evaluates FunctionWithDtResourceInput with two inputs:
  //   x [DT_FLOAT Const],
  //   _Arg [DT_RESOURCE _Arg]
  // and has two outputs:
  //   z1 = x + _Arg
  //   z2 = x
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(false));
    const auto& out_props =
        properties.GetOutputProperties("FunctionWithDtResourceInput");
    // ASSERT (not EXPECT): both outputs are indexed below, so a size mismatch
    // must stop the test instead of reading out of bounds.
    ASSERT_EQ(2, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [1,3]", PropToString(out_prop0));
    const OpInfo::TensorProperties& out_prop1 = out_props[1];
    EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
  }

  {
    // Delete the _handle_dtypes and _handle_shapes attrs from the DT_RESOURCE
    // _Arg input node (named "y").
    for (int i = 0; i < item.graph.node_size(); i++) {
      auto* node = item.graph.mutable_node(i);
      if (node->name() == "y") {  // _Arg node with DT_RESOURCE
        node->mutable_attr()->erase("_handle_dtypes");
        node->mutable_attr()->erase("_handle_shapes");
        break;
      }
    }
    // Without those attrs we cannot infer the function output shapes
    // correctly, but inference still shouldn't fail, and some shapes can
    // still be inferred. In this test graph, z2 just returns the x input, so
    // its shape is known even though the _Arg's shape is not.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(false));
    const auto& out_props =
        properties.GetOutputProperties("FunctionWithDtResourceInput");
    ASSERT_EQ(2, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    // Without the shape and dtype attrs, _Arg's shape is unknown; hence so is
    // the shape of x + _Arg.
    EXPECT_EQ("float: ?", PropToString(out_prop0));
    // The 2nd output is just x, so its shape can be inferred even though
    // _Arg's shape is unknown.
    const OpInfo::TensorProperties& out_prop1 = out_props[1];
    EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
  }
}
1456
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
  // Call MyFillFunc (from the fixture's function library) with the constant
  // shape {1, 2, 3, 4} and a scalar float value. The output is expected to be
  // a float tensor of that shape.
  Output fill_shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
  Output fill_value = ops::Const(s.WithOpName("value"), 0.1f, {});
  tensorflow::Node* fill_node;
  TF_ASSERT_OK(tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                       s.graph()->op_registry())
                   .Input(tensorflow::ops::AsNodeOut(s, fill_shape))
                   .Input(tensorflow::ops::AsNodeOut(s, fill_value))
                   .Finalize(s.graph(), &fill_node));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& fill_props = properties.GetOutputProperties("MyFillFunc");
  const OpInfo::TensorProperties& filled = fill_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(filled));
}
1478
TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
  // Variant of FunctionWithConstInput in which MyFillFunc's shape argument is
  // Identity(Const) instead of a Const. Because the input is not itself a
  // Const, the function's constant input has to come from the propagated
  // tensors-as-shapes information rather than from a tensor value.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
  Output const_shape = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
  Output identity_shape = ops::Identity(s.WithOpName("shape"), const_shape);
  Output fill_value = ops::Const(s.WithOpName("value"), 0.1f, {});

  const auto shape_arg = tensorflow::ops::AsNodeOut(s, identity_shape);
  const auto value_arg = tensorflow::ops::AsNodeOut(s, fill_value);
  auto func_builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                              s.graph()->op_registry());
  tensorflow::Node* func_node;
  TF_ASSERT_OK(func_builder.Input(shape_arg).Input(value_arg).Finalize(
      s.graph(), &func_node));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& func_outputs = properties.GetOutputProperties("MyFillFunc");
  const OpInfo::TensorProperties& filled = func_outputs[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(filled));
}
1504
TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
  // MyFunc(x: int32) -> out: int32 simply returns Identity(x).
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32"},                                               // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));

  // Feed the int32 Const {5, 7} (shape [2]) into MyFunc. Because the function
  // only forwards its input, its output should carry the same shape and value
  // (output_tensors_as_shape) as the Const.
  Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
  tensorflow::Node* my_func_node;
  TF_ASSERT_OK(
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry())
          .Input(tensorflow::ops::AsNodeOut(s, shape))
          .Finalize(s.graph(), &my_func_node));

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/true));
  const auto& func_outputs = properties.GetOutputProperties("MyFunc");
  const OpInfo::TensorProperties& result = func_outputs[0];
  EXPECT_EQ("int32: [2]", PropToString(result));
  EXPECT_TRUE(result.has_value());
  ExpectTensorValues({5, 7}, result.value());
  ExpectTensorValues({5, 7},
                     properties.GetInputProperties("MyFunc")[0].value());
}
1540
TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
  // MyFunc(x, y) = x + y, called with the same Const {5, 7} for both inputs.
  // The output value {10, 14} should only be evaluated when
  // aggressive_shape_inference is enabled.
  FunctionDefLibrary library;
  // Function that adds two input values.
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32", "y: int32"},                                   // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:z:0"}});                                        // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));

  Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op;
  TF_ASSERT_OK(
      builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, the internal function does not
    // evaluate the output value; only its shape is known.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto out_props = properties.GetOutputProperties("MyFunc");
    // Guard the indexing below so that a missing entry fails cleanly.
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties out_prop0 = out_props[0];
    EXPECT_EQ("int32: [2]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }

  {
    GraphProperties properties(item);
    // With aggressive_shape_inference, the output value is evaluated.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto out_props = properties.GetOutputProperties("MyFunc");
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties out_prop0 = out_props[0];
    EXPECT_EQ("int32: [2]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());

    ExpectTensorValues({10, 14}, out_prop0.value());
    const auto in_props = properties.GetInputProperties("MyFunc");
    ASSERT_EQ(2, in_props.size());
    ExpectTensorValues({5, 7}, in_props[0].value());
    ExpectTensorValues({5, 7}, in_props[1].value());
  }
}
1597
1598 // Same as the above, but float values; also, one of the function input is
1599 // Identity of Const.
TEST_F(GraphPropertiesTest,ArithmeticFunctionReturnTensorValueFloat)1600 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValueFloat) {
1601 FunctionDefLibrary library;
1602 // Function that adds two input values.
1603 *library.add_function() = FunctionDefHelper::Create(
1604 "MyFunc", // Name
1605 {"x: float", "y: float"}, // Inputs
1606 {"out: float"}, // Outputs
1607 {}, // Attrs
1608 {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_FLOAT}}}}, // Nodes
1609 {{"out", "a:z:0"}}); // Returns
1610 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1611 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1612
1613 Output x1 = ops::Const(s.WithOpName("x1"), {5.0f, 7.0f}, {2});
1614 Output x2 = ops::Identity(s.WithOpName("x1"), x1);
1615 auto _x1 = tensorflow::ops::AsNodeOut(s, x1);
1616 auto _x2 = tensorflow::ops::AsNodeOut(s, x2);
1617 auto builder =
1618 tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1619 tensorflow::Node* func_op;
1620 TF_ASSERT_OK(builder.Input(_x1).Input(_x2).Finalize(s.graph(), &func_op));
1621
1622 GrapplerItem item;
1623 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1624 {
1625 GraphProperties properties(item);
1626 // Without aggressive_shape_inference, the internal function does not
1627 // evaluate output value.
1628 TF_ASSERT_OK(properties.InferStatically(
1629 /*assume_valid_feeds=*/true,
1630 /*aggressive_shape_inference=*/false,
1631 /*include_tensor_values=*/true));
1632 const auto out_props = properties.GetOutputProperties("MyFunc");
1633 const OpInfo::TensorProperties out_prop0 = out_props[0];
1634 EXPECT_EQ("float: [2]", PropToString(out_prop0));
1635 EXPECT_FALSE(out_prop0.has_value());
1636 }
1637
1638 {
1639 GraphProperties properties(item);
1640 // With aggressive_shape_inference, output value is evaluated.
1641 TF_ASSERT_OK(properties.InferStatically(
1642 /*assume_valid_feeds=*/true,
1643 /*aggressive_shape_inference=*/true,
1644 /*include_tensor_values=*/true));
1645 const auto out_props = properties.GetOutputProperties("MyFunc");
1646 const OpInfo::TensorProperties out_prop0 = out_props[0];
1647 EXPECT_EQ("float: [2]", PropToString(out_prop0));
1648 EXPECT_TRUE(out_prop0.has_value());
1649
1650 ExpectFloatTensorValues({10.0, 14.0}, out_prop0.value());
1651 ExpectFloatTensorValues({5.0, 7.0},
1652 properties.GetInputProperties("MyFunc")[0].value());
1653 ExpectFloatTensorValues({5.0, 7.0},
1654 properties.GetInputProperties("MyFunc")[1].value());
1655 }
1656 }
1657
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
  // Create a graph with a function that takes a scalar value, so that a scalar
  // Placeholder is used as the input for the function's shape inference.
  // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
  // the input; all tensors are scalars.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float"},                                               // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
  Output placeholder =
      ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({})));
  Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
  auto _identity = tensorflow::ops::AsNodeOut(s, identity);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op;
  TF_ASSERT_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // TensorFlow graph producer versions <= 21 infer the output shape of a
  // Placeholder with an empty shape as unknown instead of scalar, so this test
  // needs a producer version of at least 22 (hence "greater than 21").
  EXPECT_GT(item.graph.versions().producer(), 21);

  // MyFunc output shouldn't be unknown rank.
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/true));
  const auto out_props = properties.GetOutputProperties("MyFunc");
  // Guard the indexing below so that a missing entry fails cleanly.
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties out_prop0 = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());
}
1697
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x,y)

    with tf.Graph().as_default():
      x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = MyAdd(x, y)
      z = MyAdd(x, z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "simple_function.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // "MyAdd_55e046a8" is a MyAdd call node in the serialized graph; its output
  // and both of its inputs are expected to be float [1, 2]. Sizes and ranks
  // are ASSERTed so that the indexing below never reads out of range.
  const auto out_props = properties.GetOutputProperties("MyAdd_55e046a8");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_FALSE(out_prop.shape().unknown_rank());
  ASSERT_EQ(2, out_prop.shape().dim_size());
  EXPECT_EQ(1, out_prop.shape().dim(0).size());
  EXPECT_EQ(2, out_prop.shape().dim(1).size());

  const auto in_props = properties.GetInputProperties("MyAdd_55e046a8");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1734
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
  // Node "y0" in large_function_graph.pbtxt has two outputs and four inputs;
  // check the statically inferred shape of each.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_function_graph.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // ASSERT the counts: the code below indexes up to out_props[1] and
  // in_props[3].
  const auto out_props = properties.GetOutputProperties("y0");
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));

  const auto in_props = properties.GetInputProperties("y0");
  ASSERT_EQ(4, in_props.size());

  const OpInfo::TensorProperties& in_prop0 = in_props[0];
  EXPECT_EQ("float: [64]", PropToString(in_prop0));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));

  const OpInfo::TensorProperties& in_prop2 = in_props[2];
  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));

  const OpInfo::TensorProperties& in_prop3 = in_props[3];
  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
1767
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
  // Test graph produced in python using:
  /*
    @function.Defun(noinline=True)
    def MyFunc():
      @function.Defun(*[tf.float32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[tf.float32] * 2)
      def Body(n, x):
        return n - 1, x + n

      i = tf.constant(10)
      return functional_ops.While([i, 0.], Cond, Body)

    with tf.Graph().as_default():
      z = MyFunc()
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_functional_while.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // MyFunc returns both While outputs (int32 counter, float accumulator);
  // neither should be left with unknown rank. ASSERT the count because both
  // entries are indexed below.
  const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_INT32, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
  EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
1805
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
  // Inference over function_error.pbtxt must still succeed. The function call's
  // output is expected to be float with unknown rank, while its two inputs
  // keep their known [1, 2] shapes. (The function body presumably fails
  // shape inference, per the file name; confirm against the pbtxt.)
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_error.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // ASSERT the counts so that the indexing below never reads out of range.
  const auto out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(1, out_props.size());

  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_TRUE(out_prop.shape().unknown_rank());

  const auto in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1830
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // MyAdd has one output and two inputs; ASSERT the counts before indexing.
  const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1864
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_2.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // MyAdd has one output and two inputs; ASSERT the counts before indexing.
  const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1897
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      c = tf.add(x, a)
      d = tf.add(y, b)
      return c

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_shapes.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // MyAdd returns only `c`, so it has one output; its two inputs have
  // different shapes ([1,2] and [1,3]). ASSERT the counts before indexing.
  const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
1934
TEST_F(GraphPropertiesTest, SymbolicShapes) {
  // Build a simple graph with placeholders of unknown dimensions. These
  // dimensions will be encoded symbolically: each unknown dim is reported as a
  // size <= -2, and dims that inference proves equal share the same value.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b =
      ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output c = ops::Identity(s.WithOpName("c"), a);
  Output d = ops::Identity(s.WithOpName("d"), b);
  Output e = ops::Add(s.WithOpName("e"), c, d);
  Output f = ops::Add(s.WithOpName("f"), a, c);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output g = ops::Shape(s.WithOpName("g"), c);
  Output h = ops::Fill(s.WithOpName("h"), g, zero);
  Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
  Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // Ranks are ASSERTed before dims are indexed below.
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(shape_a.dim_size(), shape_c.dim_size());
  EXPECT_GE(-2, shape_a.dim(0).size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
  EXPECT_GE(-2, shape_a.dim(1).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());

  PartialTensorShape shape(shape_a);
  EXPECT_FALSE(shape.IsFullyDefined());
  EXPECT_FALSE(shape.unknown_rank());

  const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
  const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
  ASSERT_EQ(1, shape_b.dim_size());
  ASSERT_EQ(shape_b.dim_size(), shape_d.dim_size());
  EXPECT_GE(-2, shape_b.dim(0).size());
  EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
  EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());

  const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
  ASSERT_EQ(2, shape_e.dim_size());
  EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
  EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
  EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());

  const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
  ASSERT_EQ(2, shape_f.dim_size());
  EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
  EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());

  // h = Fill(Shape(c), 0) must have the same (symbolic) shape as c. Check h's
  // own rank before indexing its dims.
  const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
  ASSERT_EQ(2, shape_h.dim_size());
  EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
  EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());

  // j = Sum(a, axis 0) keeps only a's second dim.
  const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
  ASSERT_EQ(1, shape_j.dim_size());
  EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
2003
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
  // Three constants, where "c" is colocated with "a".
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(root.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(root.WithOpName("b"), 2.0f, {1});
  Output c = ops::Const(root.WithOpName("c").ColocateWith(a), 3.0f, {1});
  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  // Drop node "a" (as some late graph optimization pass might) even though
  // "c" still names it as its colocation target. At that stage the colocation
  // constraints have already been validated and device placement is complete,
  // so shape inference should not re-check them.
  GraphDef pruned_graph;
  for (const auto& node : item.graph.node()) {
    if (node.name() == "a") continue;
    *pruned_graph.add_node() = node;
  }
  item.graph.Swap(&pruned_graph);

  // InferStatically does not validate colocation constraints internally, so
  // it is expected to return OK here.
  GraphProperties properties(item);
  TF_EXPECT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));
}
2027
TEST_F(GraphPropertiesTest, ShapeTracking) {
  // Fill ops driven by ShapeN of two partially-unknown placeholders should end
  // up with exactly the placeholders' shapes.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b =
      ops::Placeholder(root.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output zero = ops::Const(root.WithOpName("zero"), 0.0f, {});
  auto shapes = ops::ShapeN(root.WithOpName("shapes"), {a, b});
  Output o1 = ops::Fill(root.WithOpName("o1"), shapes[0], zero);
  Output o2 = ops::Fill(root.WithOpName("o2"), shapes[1], zero);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  // Returns (by value) the shape of the first output of `node_name`.
  auto first_output_shape = [&properties](const string& node_name) {
    return properties.GetOutputProperties(node_name).at(0).shape();
  };
  const auto shape_a = first_output_shape("a");
  const auto shape_b = first_output_shape("b");
  const auto shape_o1 = first_output_shape("o1");
  const auto shape_o2 = first_output_shape("o2");
  EXPECT_EQ(shape_a.DebugString(), shape_o1.DebugString());
  EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString());
}
2053
TEST_F(GraphPropertiesTest, FedNodes) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  // ASSERT_TRUE instead of CHECK: a missing item should fail only this test
  // rather than abort the whole test binary.
  ASSERT_TRUE(fake_input.NextItem(&item));

  {
    // Conservative shape analysis: the shape of fed ports should be unknown.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));
    for (const auto& node : item.graph.node()) {
      if (node.op() == "Const") {
        continue;
      }
      // ASSERT the counts because both vectors are indexed right after.
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      const OpInfo::TensorProperties& out_prop = out_props[0];

      if (node.name() == "x") {
        // x is fed: its input should have a known shape, while its output
        // doesn't.
        EXPECT_FALSE(in_prop.shape().unknown_rank());
        ASSERT_EQ(1, in_prop.shape().dim_size());
        EXPECT_EQ(2, in_prop.shape().dim(0).size());
        EXPECT_TRUE(out_prop.shape().unknown_rank());
      } else if (node.op() == "Square" || node.op() == "AddN") {
        // These nodes are in the fanout of x: their shapes should be unknown.
        EXPECT_TRUE(in_prop.shape().unknown_rank());
        EXPECT_TRUE(out_prop.shape().unknown_rank());
      }
    }
  }
  {
    // Optimistic shape analysis: the shape of fed ports should be derived from
    // the shape of the fanin.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/true));
    for (const auto& node : item.graph.node()) {
      if (node.op() == "Square" || node.op() == "AddN") {
        const auto in_props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, in_props.size());
        const OpInfo::TensorProperties& in_prop = in_props[0];
        EXPECT_EQ(DT_FLOAT, in_prop.dtype());
        EXPECT_FALSE(in_prop.shape().unknown_rank());
        EXPECT_EQ(2, in_prop.shape().dim_size());
        const auto out_props = properties.GetOutputProperties(node.name());
        ASSERT_EQ(1, out_props.size());
        const OpInfo::TensorProperties& out_prop = out_props[0];
        EXPECT_EQ(in_prop.dtype(), out_prop.dtype());
        EXPECT_EQ(in_prop.shape().DebugString(),
                  out_prop.shape().DebugString());
      }
    }
  }
}
2114
TEST_F(GraphPropertiesTest, Performance) {
  // Load a large graph with many nested loops and run static shape inference
  // on it. The test asserts only that every step succeeds; it makes no
  // explicit timing assertion.
  const string filename = io::JoinPath(testing::TensorFlowSrcRoot(),
                                       kTestDataPath, "large_graph.pbtxt.html");
  GrapplerItem item;
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));

  // Fill in default attr values using the global op registry plus the
  // graph's own function library.
  const FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                                   item.graph.library());
  TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&item.graph, function_library,
                                         /*node_offset=*/0,
                                         /*skip_unknown_ops=*/true));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));
}
2130
TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
  // b = shape(a)[0:1] and c = shape(a)[1:2]. Fill ops driven by these slices
  // should produce 1-D outputs whose single dim equals the corresponding
  // (symbolic) dim of a.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  auto shp = ops::Shape(s.WithOpName("shape"), {a});

  Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
  Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
  Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});

  // index2 (== [1]) also serves as the stride for both slices.
  Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
  Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(/*assume_valid_feeds=*/false));
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
  const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
  // ASSERT the ranks: the dims are indexed directly below.
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(1, shape_o1.dim_size());
  ASSERT_EQ(1, shape_o2.dim_size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
}
2163
TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
  // slice = StridedSlice(shape(input_placeholder), [0], [3], [1]) with
  // shrink_axis_mask = 1; for an input of shape [5, 480, 40, 1] its value
  // should be inferred as 5, but only under aggressive shape inference.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(scope.WithOpName("input_placeholder"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
  auto input_shape = ops::Shape(scope.WithOpName("input_shape"), placeholder);

  Output begin = ops::Const(scope.WithOpName("begin"), {0}, {1});
  Output end = ops::Const(scope.WithOpName("end"), {3}, {1});
  Output stride = ops::Const(scope.WithOpName("stride"), {1}, {1});

  Output slice = ops::StridedSlice(scope.WithOpName("slice"), input_shape,
                                   begin, end, stride,
                                   ops::StridedSlice::ShrinkAxisMask(1));

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  {
    // Default (non-aggressive) inference leaves the slice value unknown.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const OpInfo::TensorProperties slice_prop =
        properties.GetOutputProperties("slice").at(0);
    EXPECT_FALSE(slice_prop.has_value());
  }

  {
    // Aggressive inference evaluates the slice value.
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const OpInfo::TensorProperties slice_prop =
        properties.GetOutputProperties("slice").at(0);
    EXPECT_TRUE(slice_prop.has_value());
    ExpectTensorValues({5}, slice_prop.value());
  }
}
2207
TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
  // With aggressive shape inference and tensor values enabled, constant
  // values should propagate through OnesLike and a chain of Add ops.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(root.WithOpName("a"), {5, 7}, {2});
  Output b = ops::Const(root.WithOpName("b"), {8, 8}, {2});
  Output c = ops::Const(root.WithOpName("c"), {2, 2}, {2});

  Output a1 = ops::OnesLike(root.WithOpName("a1"), a);
  Output a_plus_one = ops::Add(root.WithOpName("a_plus_one"), a, a1);
  Output a_plus_a = ops::Add(root.WithOpName("a_plus_a"), a, a);
  Output b_plus_2a = ops::Add(root.WithOpName("b_plus_2a"), b, a_plus_a);
  Output c_plus_b_plus_2a =
      ops::Add(root.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));

  // Returns a copy of the first output property of `node_name`.
  auto first_output = [&properties](const string& node_name) {
    return properties.GetOutputProperties(node_name)[0];
  };

  // Every checked node is an int32 vector of length 2 with a known value.
  const OpInfo::TensorProperties a_plus_one_prop = first_output("a_plus_one");
  EXPECT_EQ("int32: [2]", PropToString(a_plus_one_prop));
  EXPECT_TRUE(a_plus_one_prop.has_value());
  ExpectTensorValues({6, 8}, a_plus_one_prop.value());

  const OpInfo::TensorProperties a_plus_a_prop = first_output("a_plus_a");
  EXPECT_EQ("int32: [2]", PropToString(a_plus_a_prop));
  EXPECT_TRUE(a_plus_a_prop.has_value());
  ExpectTensorValues({10, 14}, a_plus_a_prop.value());

  const OpInfo::TensorProperties b_plus_2a_prop = first_output("b_plus_2a");
  EXPECT_EQ("int32: [2]", PropToString(b_plus_2a_prop));
  EXPECT_TRUE(b_plus_2a_prop.has_value());
  ExpectTensorValues({18, 22}, b_plus_2a_prop.value());

  const OpInfo::TensorProperties c_plus_b_plus_2a_prop =
      first_output("c_plus_b_plus_2a");
  EXPECT_EQ("int32: [2]", PropToString(c_plus_b_plus_2a_prop));
  EXPECT_TRUE(c_plus_b_plus_2a_prop.has_value());
  ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
}
2251
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  // An Identity node carrying an _output_shape_vector annotation of [5, 7]
  // (plus _same_output_for_iterations) should have the annotation honored
  // only when aggressive_shape_inference is enabled.
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  // NOTE(review): Identity's type attr is "T" (as in the Identity FunctionDef
  // nodes elsewhere in this file), which NodeDefBuilder presumably infers from
  // the DT_FLOAT input; the "dtype" attr below looks unused. Confirm before
  // removing.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({5, 7})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto props = properties.GetOutputProperties("Identity");
    // ASSERT the count: props[0] is read right after.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto props = properties.GetOutputProperties("Identity");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
2296
// Checks that an annotated output shape ([10,100]) that is compatible with the
// statically inferred shape ([-1,100]) is used to refine the Identity output
// when aggressive_shape_inference is enabled.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 100})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // Fatal check: props[0] below must never be read from an empty vector.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2324
// Checks that an annotated output shape ([10,10]) that conflicts with the
// statically inferred shape ([-1,100]) is ignored, even with
// aggressive_shape_inference enabled.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 10})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // Fatal check: props[0] below must never be read from an empty vector.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
2352
// Checks that with aggressive_shape_inference, the annotated output shape
// ([10,100]) is reported for a TestOpWithNoInferenceFn node. The op's
// registration is not visible here; by its name it presumably has no shape
// function — confirm against its REGISTER_OP.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(
      NodeDefBuilder("TestOpWithNoInferenceFn", "TestOpWithNoInferenceFn")
          .Attr("_same_output_for_iterations", true)
          .Attr("_output_shape_vector", {TensorShape({10, 100})})
          .Input("Input", 0, DT_FLOAT)
          .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
  // Fatal check: props[0] below must never be read from an empty vector.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2379
// Checks that static shape inference reports the output of a PartitionedCall
// to "identity_function" as int32 [4], matching its 4-element constant input.
TEST_F(GraphPropertiesTest, PartitionedCallOp) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;
  FunctionDef called_func = FunctionDefHelper::Create(
      "identity_function",
      /*in_def=*/{"arg0: int32"},
      /*out_def=*/{"ret0: int32"},
      /*attr_def=*/{},
      {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "Identity:output:0"}});
  *library.add_function() = called_func;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));

  Output in = ops::Const(root, {3, 1, 2, 0});
  NameAttrList b_name_attr;
  b_name_attr.set_name("identity_function");
  ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& call_props = properties.GetOutputProperties("identity_call");
  // Fatal check: call_props[0] below must never be read from an empty vector.
  ASSERT_EQ(1, call_props.size());
  EXPECT_EQ("int32: [4]", PropToString(call_props[0]));
}
2411
// Checks that a PartitionedCall to a function that Adds two partially known
// [2,2,-1] int32 placeholders is inferred to produce an int32 [2,2,-1] output.
TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
  auto f = FunctionDefHelper::Create(
      // Name
      "FunctionWhichAdds",
      // Inputs
      {"arg0: int32", "arg1: int32"},
      // Outputs
      {"ret0: int32"},
      /*attr_def=*/{},
      // Nodes
      {{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "a:z:0"}});

  FunctionDefLibrary function_lib;
  function_lib.add_function()->Swap(&f);
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib));

  PartialTensorShape input_shape({2, 2, -1});
  Output in1 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  Output in2 =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
  NameAttrList b_name_attr;
  b_name_attr.set_name("FunctionWhichAdds");
  ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
                            b_name_attr);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& call_props = properties.GetOutputProperties("add_call");
  // Fatal check: call_props[0] below must never be read from an empty vector.
  ASSERT_EQ(1, call_props.size());
  EXPECT_EQ("int32: [2,2,-1]", PropToString(call_props[0]));
}
2452
// Checks that a function-call node whose output shape cannot be inferred
// statically reports an unknown shape without aggressive_shape_inference, and
// reports its "_output_shape_vector" annotation ([1,2,3,4]) with it.
TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
  // A function, which we cannot infer output shape statically.
  auto f = FunctionDefHelper::Create(
      // Name
      "FuncShapeCannotBeInferred",
      // Inputs
      {},
      // Outputs
      {"output: float"},
      // Attrs
      {},
      // Nodes
      {
          // Placeholder without shape attr; unknown rank.
          {{"p"}, "Placeholder", {}, {{"dtype", DataType::DT_FLOAT}}},
      },
      // Returns
      {{"output", "p:output:0"}});
  FunctionDefLibrary function_lib;
  function_lib.add_function()->Swap(&f);
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib));
  tensorflow::Node* func_op;
  TensorShapeProto output_shape;
  output_shape.set_unknown_rank(false);
  output_shape.add_dim()->set_size(1);
  output_shape.add_dim()->set_size(2);
  output_shape.add_dim()->set_size(3);
  output_shape.add_dim()->set_size(4);
  // The function node, f, includes shape annotation.
  TF_ASSERT_OK(tensorflow::NodeBuilder("f", "FuncShapeCannotBeInferred",
                                       s.graph()->op_registry())
                   .Attr("_execution_count", 1)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_dtype_vector", {DataType::DT_FLOAT})
                   .Attr("_output_shape_vector", {output_shape})
                   .Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // InferStatically without aggressive_shape_inference fails to infer the
  // output shape of the node f.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
    const auto& out_props = properties.GetOutputProperties("f");
    // Fatal check: out_props[0] below must never be read from an empty vector.
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: ?", PropToString(out_prop0));
  }
  // With aggressive_shape_inference, it skips recursively calling
  // InferStatically for the function node and outputs annotated shape info.
  {
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& out_props = properties.GetOutputProperties("f");
    // Fatal check: out_props[0] below must never be read from an empty vector.
    ASSERT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
  }
}
2518
// This graph creates a shape vector [-1, 10] from Concat(Const, Const) and
// uses it for two Reshape ops. One Reshape output is the segment_ids input to
// UnsortedSegmentSum, whose shape function applies MergePrefix. segment_ids has
// shape [-1, 10] (from the Reshape), but MergePrefix with the data input
// ([10, 10, 10, 10]) turns the -1 (unknown) dim into 10 under
// SymbolicShapeRefiner.
// That dim value (10) must not affect the other Reshape op, even though it
// shares the shape input: -1 in the shape input of a Reshape op is a special
// "computed output dim", not an unknown dim.
TEST_F(GraphPropertiesTest,
       SymbolicShapeInferenceWithReshapeOpsSharingShapeVector) {
  GrapplerItem item;

  // data and num_segments are inputs to UnsortedSegmentSum.
  TF_ASSERT_OK(NodeDefBuilder("data", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({10, 10, 10, 10}))
                   .Finalize(item.graph.add_node()));
  Tensor num_segments(DT_INT32, TensorShape({}));
  // Build the segment_ids input to UnsortedSegmentSum from Const, ConcatV2,
  // and Reshape ops. tensors_as_shape from the Const ops are propagated to the
  // ConcatV2 output to form the shape vector [-1, 10] fed to Reshape.
  test::FillIota<int>(&num_segments, 3);
  TF_ASSERT_OK(NodeDefBuilder("num_segments", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", num_segments)
                   .Finalize(item.graph.add_node()));
  Tensor minus_one(DT_INT32, TensorShape({1}));
  test::FillIota<int>(&minus_one, -1);
  TF_ASSERT_OK(NodeDefBuilder("minus_one", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", minus_one)
                   .Finalize(item.graph.add_node()));
  Tensor plus_ten(DT_INT32, TensorShape({1}));
  test::FillIota<int>(&plus_ten, 10);
  TF_ASSERT_OK(NodeDefBuilder("plus_ten", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", plus_ten)
                   .Finalize(item.graph.add_node()));
  Tensor axis(DT_INT32, TensorShape({}));
  test::FillIota<int>(&axis, -1);
  TF_ASSERT_OK(NodeDefBuilder("axis", "Const")
                   .Attr("dtype", DT_INT32)
                   .Attr("value", axis)
                   .Finalize(item.graph.add_node()));
  std::vector<NodeDefBuilder::NodeOut> inputs(2);
  inputs[0] = NodeDefBuilder::NodeOut{"minus_one", 0, DT_INT32};
  inputs[1] = NodeDefBuilder::NodeOut{"plus_ten", 0, DT_INT32};
  TF_ASSERT_OK(NodeDefBuilder("concat", "ConcatV2")
                   .Input(inputs)
                   .Input("axis", 0, DT_INT32)
                   .Attr("N", 2)
                   .Attr("T", DT_INT32)
                   .Attr("Tidx", DT_INT32)
                   .Finalize(item.graph.add_node()));
  // NOTE(review): segment_ids_ (and hence the "segment_ids" Reshape output) is
  // DT_FLOAT, yet UnsortedSegmentSum below declares that input as DT_INT32.
  // The shape-inference path exercised here does not appear to depend on the
  // dtype — confirm before relying on this graph for anything else.
  TF_ASSERT_OK(NodeDefBuilder("segment_ids_", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("segment_ids_shape_before_reshape", "Shape")
                   .Input("segment_ids_", 0, DT_FLOAT)
                   .Attr("T", DT_FLOAT)
                   .Attr("out_type", DT_INT32)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("segment_ids", "Reshape")
                   .Input("segment_ids_", 0, DT_FLOAT)
                   .Input("concat", 0, DT_INT32)
                   .Attr("T", DT_FLOAT)
                   .Attr("Tshape", DT_INT32)
                   .Finalize(item.graph.add_node()));
  // The shape function of UnsortedSegmentSum applies MergePrefix to data and
  // segment_ids (the latter being the prefix). data shape is [10,10,10,10] and
  // segment_ids shape is [-1, 10], but MergePrefix and symbolic shape inference
  // assign 10 from the data shape to the unknown dim in segment_ids.
  TF_ASSERT_OK(NodeDefBuilder("y", "UnsortedSegmentSum")
                   .Input("data", 0, DT_FLOAT)
                   .Input("segment_ids", 0, DT_INT32)
                   .Input("num_segments", 0, DT_INT32)
                   .Attr("T", DT_FLOAT)
                   .Attr("Tindices", DT_INT32)
                   .Attr("Tnumsegments", DT_INT32)
                   .Finalize(item.graph.add_node()));
  // Note that y1 = Reshape(x1) uses the same shape vector as segment_ids, but
  // y1's shape shouldn't be affected by symbolic shape inference on
  // segment_ids.
  TF_ASSERT_OK(NodeDefBuilder("x1", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("y1", "Reshape")
                   .Input("x1", 0, DT_FLOAT)
                   .Input("concat", 0, DT_INT32)
                   .Attr("T", DT_FLOAT)
                   .Attr("Tshape", DT_INT32)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const auto& y1_output_properties = properties.GetOutputProperties("y1");
  // y1 = Reshape(x1), but x1's shape is unknown, so y1 should be [-1, 10].
  // The first dimension should not be 10.
  // Fatal checks: they guard the [0] and dim(1) accesses below.
  ASSERT_EQ(y1_output_properties.size(), 1);
  ASSERT_EQ(y1_output_properties[0].shape().dim_size(), 2);
  EXPECT_LT(y1_output_properties[0].shape().dim(0).size(), 0);
  EXPECT_EQ(y1_output_properties[0].shape().dim(1).size(), 10);
}
2622
// Unit test for IsShapeFullyDefinedIntegerVectorOrScalar(). The cases below
// cover the following:
//   - It accepts a fully defined scalar or rank-1 `shape` with a fully defined
//     `tensor_as_shape` and an INT32/INT64 dtype.
//   - It rejects non-integer dtypes.
//   - It rejects unknown dims, INT64_MAX ("unknown from Const") dims and
//     unknown shapes.
//   - It rejects a `shape` of rank > 1.
TEST(HelperFunctions, IsShapeFullyDefinedIntegerVectorOrScalar) {
  // Make a dummy InferenceContext.
  NodeDef node_def;
  OpRegistrationData op_reg_data;
  OpDefBuilder b("dummy");
  // TF_ASSERT_OK fails only this test on error; CHECK would abort the binary.
  TF_ASSERT_OK(b.Finalize(&op_reg_data));
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types;
  InferenceContext ic(/*graph_def_version=*/0, node_def, op_reg_data.op_def,
                      /*input_shapes=*/{},
                      /*input_tensors=*/{},
                      /*input_tensors_as_shapes=*/{},
                      std::move(input_handle_shapes_and_types));

  // ShapeHandles for testing.
  ShapeHandle fully_defined_vector = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle vector_with_unknown = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(5), ic.UnknownDim(), ic.MakeDim(7)});
  // INT64_MAX is used as unknown from Const. See kUnknownFromConst const in
  // graph_properties.cc
  ShapeHandle vector_with_unknown_from_const = ic.MakeShape(
      {ic.MakeDim(4), ic.MakeDim(INT64_MAX), ic.MakeDim(6), ic.MakeDim(7)});
  ShapeHandle rank_1_vector = ic.MakeShape({ic.MakeDim(4)});

  // Rank-1 shape and fully defined tensor_as_shape with INT32 or INT64.
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT32));
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_INT64));

  // Scalar (rank-0) shape: the "OrScalar" half of the contract.
  EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, ic.Scalar(), rank_1_vector, DT_INT32));

  // Non-integer data type.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, fully_defined_vector, DT_FLOAT));

  // tensor_as_shape including Unknown or UnknownFromConst.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, vector_with_unknown_from_const, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, rank_1_vector, ic.UnknownShape(), DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, ic.UnknownShape(), fully_defined_vector, DT_INT32));

  // shape rank > 1.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &ic, fully_defined_vector, vector_with_unknown_from_const, DT_INT32));
}
2672 } // namespace
2673 } // namespace grappler
2674 } // namespace tensorflow
2675