/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {
namespace {

class LayoutOptimizerTest : public ::testing::Test {
 protected:
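  // All tests run against a virtual cluster with a single GPU, so the
  // optimizer is expected to rewrite eligible nodes from NHWC to NCHW.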
  void SetUp() override {
    DeviceProperties device_properties;
    device_properties.set_type("GPU");
    device_properties.mutable_environment()->insert({"architecture", "6"});
    virtual_cluster_.reset(
        new VirtualCluster({{"/GPU:0", device_properties}}));
  }

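  // Builds Input -> Conv2D <- Filter in NHWC format with fully known shapes
  // and returns the Conv2D output. The optional device string is attached to
  // the Conv2D node only.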
  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding) {
    return SimpleConv2D(s, input_size, filter_size, padding, "");
  }

  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding, const string& device) {
    int batch_size = 8;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    TensorShape input_shape(
        {batch_size, input_height, input_width, input_depth});
    Tensor input_data(DT_FLOAT, input_shape);
    test::FillIota<float>(&input_data, 1.0f);
    Output input =
        ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Tensor filter_data(DT_FLOAT, filter_shape);
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

    Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device),
                              input, filter, {1, stride, stride, 1}, padding);
    return conv;
  }

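  // Builds a Conv2DBackpropInput node whose input_sizes argument is either a
  // Const node (const_input_size == true) or an Identity of that Const.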
  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding) {
    return SimpleConv2DBackpropInput(s, input_size, filter_size, padding,
                                     true);
  }

  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding,
                                   bool const_input_size) {
    int batch_size = 128;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    TensorShape input_sizes_shape({4});
    Tensor input_data(DT_INT32, input_sizes_shape);
    test::FillValues<int>(
        &input_data, {batch_size, input_height, input_width, input_depth});
    Output input_sizes =
        ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Tensor filter_data(DT_FLOAT, filter_shape);
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

    int output_height = input_height;
    int output_width = input_width;
    TensorShape output_shape(
        {batch_size, output_height, output_width, filter_count});
    Tensor output_data(DT_FLOAT, output_shape);
    test::FillIota<float>(&output_data, 1.0f);
    Output output =
        ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));

    Output conv_backprop_input;
    Output input_sizes_i =
        ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
    if (const_input_size) {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
          {1, stride, stride, 1}, padding);
    } else {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
          {1, stride, stride, 1}, padding);
    }
    return conv_backprop_input;
  }

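  // Extracts the tensor stored in the "value" attr of a Const node.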
  Tensor GetAttrValue(const NodeDef& node) {
    Tensor tensor;
    CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
    return tensor;
  }

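  // Builds a FusedBatchNormGrad node over NHWC inputs and returns its
  // x_backprop output.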
  Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
    int batch_size = 16;
    int input_height = 8;
    int input_width = 8;
    int input_channels = 3;
    TensorShape shape({batch_size, input_height, input_width, input_channels});
    Tensor data(DT_FLOAT, shape);
    test::FillIota<float>(&data, 1.0f);
    Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
    Output y_backprop =
        ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));

    TensorShape shape_vector({input_channels});
    Tensor data_vector(DT_FLOAT, shape_vector);
    test::FillIota<float>(&data_vector, 2.0f);
    Output scale =
        ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
    Output reserve1 =
        ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
    Output reserve2 =
        ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));

    ops::FusedBatchNormGrad::Attrs attrs;
    attrs.is_training_ = is_training;
    auto output = ops::FusedBatchNormGrad(
        s->WithOpName("FusedBatchNormGrad"), y_backprop, x, scale, reserve1,
        reserve2, attrs);
    return output.x_backprop;
  }

  std::unique_ptr<VirtualCluster> virtual_cluster_;
};

TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
  auto input_sizes_node = node_map.GetNode(input_name);
  CHECK(input_sizes_node);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(input_name, conv2d_backprop_node->input(0));
  auto input_sizes = GetAttrValue(*input_sizes_node);
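  // The constant input_sizes {128, 7, 7, 3} (NHWC) should have been rewritten
  // in NCHW order.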
  Tensor input_sizes_expected(DT_INT32, {4});
  test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
  test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
}

TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(conv2d_backprop_node->input(0),
            "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto input_sizes_node = node_map.GetNode(
      "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  CHECK(input_sizes_node);
  EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
  EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, Pad) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
  auto p = ops::Pad(s.WithOpName("p"), conv, c);
  auto o = ops::Identity(s.WithOpName("o"), p);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  auto pad = node_map.GetNode("p");
  EXPECT_EQ(pad->input(0), "Conv2D");

  auto pad_const = node_map.GetNode("p-1-LayoutOptimizer");
  EXPECT_TRUE(pad_const);
  EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(pad_const->mutable_attr()->at({"value"}).tensor()));
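  // The paddings rows are permuted from NHWC order {N, H, W, C} to NCHW order
  // {N, C, H, W}: {1,2} {3,4} {5,6} {7,8} becomes {1,2} {7,8} {3,4} {5,6}.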
  Tensor tensor_expected(DT_INT32, {4, 2});
  test::FillValues<int>(&tensor_expected, {1, 2, 7, 8, 3, 4, 5, 6});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, Connectivity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto i3 = ops::Identity(s.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Shuffle the graph out of topological order to exercise multi-hop
  // connectivity handling (two nodes are considered connected if every node
  // between them is layout agnostic). If the graph were already in
  // topological order, the problem would be easier: the layout optimizer
  // would only need to check single-hop connectivity.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_i2 = node_map_original.GetNode("i2");
  node_i2->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto node_i2_output = node_map_output.GetNode("i2");
  // The layout optimizer should process i2, since it detects that i2 is
  // connected to the Conv2D node two hops away. i1 is likewise processed,
  // since it is directly connected to Conv2D. The two transposes inserted
  // between i1 and i2 should cancel each other out, leaving i2 directly
  // connected to i1.
  EXPECT_EQ(node_i2_output->input(0), "i1");
}

TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
  auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
  auto i3 = ops::Identity(s.WithOpName("i3"), sub);
  auto i4 = ops::Identity(s.WithOpName("i4"), i3);
  auto i5 = ops::Identity(s.WithOpName("i5"), i4);
  auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
  auto i6 = ops::Identity(s.WithOpName("i6"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Shuffle the graph out of topological order, for the same reason as in the
  // Connectivity test above.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_mul = node_map_original.GetNode("mul");
  node_mul->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto mul_node = node_map_output.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar_mul");
  EXPECT_EQ(mul_node->input(1), "i5");
}

TEST_F(LayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, GPUDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto bn_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(bn_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto bn_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(bn_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, SplitDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 3, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
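  // Split dimension 3 (C in NHWC) must be rewritten to 1 (C in NCHW).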
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, SplitDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 6, 2, "SAME");
  auto c = ops::Const(s.WithOpName("c"), 1, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, SplitDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 2, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, SplitDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto i1 = ops::Identity(s.WithOpName("i1"), c);
  auto split = ops::Split(s.WithOpName("split"), i1, conv, 2);
  auto i2 = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto map_node = node_map.GetNode("split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(map_node->op(), "DataFormatDimMap");
  EXPECT_EQ(map_node->input(0), "i1");
}

TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat =
      ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split:1");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "split:1");
  EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, ConcatDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, ConcatNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto i = ops::Identity(s.WithOpName("i"), axis);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, i);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  auto concat_dim =
      node_map.GetNode("concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
  EXPECT_EQ(concat_dim->input(0), "i");
}

TEST_F(LayoutOptimizerTest, ConcatDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 2);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, ConcatDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 0);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, ConcatDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, Sum) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto reduction_indices =
      ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
  auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
  auto o = ops::Identity(s.WithOpName("o"), sum);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
  // because it degrades performance in some cases.
  /*
  NodeMap node_map(&output);
  auto sum_node = node_map.GetNode("sum");
  EXPECT_EQ(sum_node->input(0), "Conv2D");
  EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
  auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {3});
  test::FillValues<int>(&tensor_expected, {0, 2, 3});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
  */
}

TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
}

TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "scalar");
}

TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto unknown_rank =
      ops::Placeholder(s.WithOpName("unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape()));
  Output c = ops::Const(s.WithOpName("c"), 3.0f, {8, 2, 2, 2});
  Output mul = ops::Mul(s.WithOpName("mul"), conv, unknown_rank);
  auto o = ops::AddN(s.WithOpName("o"), {mul, c});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  // Node mul should not be processed by the layout optimizer, because one of
  // its inputs is of unknown rank; it consumes the Conv2D output through a
  // transpose back to NHWC.
  EXPECT_EQ(mul_node->input(0),
            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "unknown");
}

TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "i");
}

TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
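  // The length-2 vector is reshaped to {1, 2, 1, 1} so that it still
  // broadcasts along the channel dimension after the NHWC-to-NCHW rewrite.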
  auto mul_const = node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
  auto mul_const = node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, SliceConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");

  auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
  Tensor begin_tensor;
  EXPECT_TRUE(begin_tensor.FromProto(
      begin_const->mutable_attr()->at({"value"}).tensor()));
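  // begin {0, 2, 3, 1} and size {4, 1, 2, 4} are permuted to NCHW index order
  // {N, C, H, W}, as checked below.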
  Tensor begin_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
  test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);

  auto size_const = node_map.GetNode("slice-2-LayoutOptimizer");
  Tensor size_tensor;
  EXPECT_TRUE(size_tensor.FromProto(
      size_const->mutable_attr()->at({"value"}).tensor()));
  Tensor size_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
  test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
}

TEST_F(LayoutOptimizerTest, SliceNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto isize = ops::Identity(s.WithOpName("isize"), size);
  auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1),
            "slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2),
            "slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto perm1 = node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm1->input(0), "ibegin");
  auto perm2 = node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm2->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm2->input(0), "isize");
}

TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto scalar =
      ops::Const(s.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
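  // ShapeN now produces NCHW shape vectors, so each consumer is fed through a
  // DataFormatVecPermute node that converts the shapes back to NHWC order.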
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0),
            "shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute1 =
      node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute1->input(0), "shapen");
  EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");
  auto vec_permute2 =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute2->input(0), "shapen:1");
  EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {7});
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "vector");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "shapen");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute->input(0), "shapen:1");
  EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAndNoNeedToTransform4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto i1 = ops::Identity(s.WithOpName("i1"), tensor_4d);
  Output i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, i2});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "i2");
}

TEST_F(LayoutOptimizerTest, Switch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  ops::Variable ctrl(s.WithOpName("ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(s.WithOpName("switch"), conv, ctrl);
  auto i1 = ops::Identity(s.WithOpName("i1"), sw.output_true);
  auto i2 = ops::Identity(s.WithOpName("i2"), sw.output_false);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto switch_node = node_map.GetNode("switch");
  EXPECT_EQ(switch_node->input(0), "Conv2D");
  EXPECT_EQ(switch_node->input(1), "ctrl");
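  // Each output port of the Switch gets its own NCHW-to-NHWC transpose:
  // output_true is port 1 ("switch:1"), output_false is port 0 ("switch").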
  auto i1_node = node_map.GetNode("i1");
  auto i2_node = node_map.GetNode("i2");
  auto trans1 = node_map.GetNode(i1_node->input(0));
  EXPECT_EQ(trans1->input(0), "switch:1");
  auto trans2 = node_map.GetNode(i2_node->input(0));
  EXPECT_EQ(trans2->input(0), "switch");
}

TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto merge = ops::Merge(s.WithOpName("merge"), {conv, i1});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "Conv2D");
  EXPECT_EQ(merge_node->input(1), "i1");
  auto i2_node = node_map.GetNode("i2");
  EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  auto transpose =
      node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(transpose->input(0), "merge");
}

TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto merge = ops::Merge(s.WithOpName("merge"), {tensor_4d, conv});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "tensor_4d");
  EXPECT_EQ(merge_node->input(1),
            "Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

TEST_F(LayoutOptimizerTest, Complex) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
  auto i = ops::Identity(s.WithOpName("i"), comp);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto complex_node = node_map.GetNode("complex");
  EXPECT_EQ(complex_node->input(0), "Conv2D");
  EXPECT_EQ(complex_node->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
}

TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
  auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto i = node_map.GetNode("identity_n");
  EXPECT_EQ(i->input(0), "vector");
  EXPECT_EQ(i->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->input(0), "identity_n:1");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "identity_n");
  EXPECT_EQ(add_node->input(1),
            "identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
1108