1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <deque>
17 #include <unordered_set>
18
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/memory_types.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/devices.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
32 #include "tensorflow/core/grappler/utils/frame.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/util/device_name_utils.h"
37
38 namespace tensorflow {
39 namespace grappler {
40 namespace {
41
// Every node created by this optimizer gets a name ending in kSuffix,
// preceded by one of the type tags below (see LayoutOptimizerNode /
// IsNodeType), so generated nodes can be recognized by name.
const char kSuffix[] = "LayoutOptimizer";
// Tags for the permutation Const nodes fed to Transpose.
const char kPermNHWCToNCHW[] = "PermConstNHWCToNCHW";
const char kPermNCHWToNHWC[] = "PermConstNCHWToNHWC";
// Tags for the Transpose nodes that convert tensors between layouts.
const char kTransposeNHWCToNCHW[] = "TransposeNHWCToNCHW";
const char kTransposeNCHWToNHWC[] = "TransposeNCHWToNHWC";
// Tags for nodes that remap dimension indices (DataFormatDimMap-style).
const char kDimMapNHWCToNCHW[] = "DimMapNHWCToNCHW";
const char kDimMapNCHWToNHWC[] = "DimMapNCHWToNHWC";
// Tags for nodes that permute shape/size vectors (DataFormatVecPermute-style).
const char kVecPermuteNHWCToNCHW[] = "VecPermuteNHWCToNCHW";
const char kVecPermuteNCHWToNHWC[] = "VecPermuteNCHWToNHWC";
// Tags for Reshape-related helper nodes.
const char kReshapeNHWCToNCHW[] = "ReshapeNHWCToNCHW";
const char kReshapeConst[] = "ReshapeConst";
53
GetOpsFormatSupported()54 std::set<string> GetOpsFormatSupported() {
55 std::set<string> ops_format_supported = {
56 "AvgPool",
57 "AvgPoolGrad",
58 "Conv2D",
59 "Conv2DBackpropFilter",
60 "Conv2DBackpropInput",
61 "BiasAdd",
62 "BiasAddGrad",
63 "DepthwiseConv2dNative",
64 "DepthwiseConv2dNativeBackpropInput",
65 "DepthwiseConv2dNativeBackpropFilter",
66 "FusedBatchNorm",
67 "FusedBatchNormV2",
68 "FusedBatchNormGrad",
69 "FusedBatchNormGradV2",
70 "FusedConv2DBiasActivation",
71 "MaxPool",
72 "MaxPoolV2",
73 "MaxPoolGrad",
74 "MaxPoolGradGrad",
75 "MaxPoolGradV2",
76 "MaxPoolGradGradV2",
77 "SpaceToDepth",
78 "DepthToSpace"};
79 return ops_format_supported;
80 }
81
GetOpsFormatAgnostic()82 std::set<string> GetOpsFormatAgnostic() {
83 std::set<string> ops_format_agnostic = {"Abs",
84 "Add",
85 "AddN",
86 "AddV2",
87 "Acos",
88 "Acosh",
89 "All",
90 "Angle",
91 "Any",
92 "ApproximateEqual",
93 "Asin",
94 "Asinh",
95 "Atan",
96 "Atan2",
97 "Atanh",
98 "Betainc",
99 "Bitcast",
100 "Cast",
101 "Ceil",
102 "CheckNumerics",
103 "Complex",
104 "ComplexAbs",
105 "Concat",
106 "ConcatV2",
107 "Conj",
108 "Cos",
109 "Cosh",
110 "Digamma",
111 "Div",
112 "Elu",
113 "EluGrad",
114 "Enter",
115 "Equal",
116 "Erf",
117 "Erfc",
118 "Exit",
119 "Exp",
120 "Expm1",
121 "FakeQuantWithMinMaxVars",
122 "FakeQuantWithMinMaxArgs",
123 "Fill",
124 "Floor",
125 "FloorDiv",
126 "FloorMod",
127 "Greater",
128 "GreaterEqual",
129 "GuaranteeConst",
130 "HistogramSummary",
131 "Identity",
132 "IdentityN",
133 "Igamma",
134 "Igammac",
135 "Imag",
136 "Inv",
137 "InvGrad",
138 "IsFinite",
139 "IsInf",
140 "IsNan",
141 "Less",
142 "LessEqual",
143 "Lgamma",
144 "Log",
145 "LogicalAnd",
146 "LogicalNot",
147 "LogicalOr",
148 "Log1p",
149 "Max",
150 "Maximum",
151 "Mean",
152 "Merge",
153 "Min",
154 "Minimum",
155 "Mod",
156 "Mul",
157 "Neg",
158 "NextIteration",
159 "NotEqual",
160 "OnesLike",
161 "Pad",
162 "PreventGradient",
163 "Prod",
164 "Polygamma",
165 "QuantizeAndDequantizeV2",
166 "QuantizeAndDequantizeV3",
167 "Pow",
168 "Real",
169 "RealDiv",
170 "Reciprocal",
171 "ReciprocalGrad",
172 "Relu",
173 "Relu6",
174 "Relu6Grad",
175 "ReluGrad",
176 "Rint",
177 "Select",
178 "Selu",
179 "SeluGrad",
180 "Shape",
181 "ShapeN",
182 "Sigmoid",
183 "SigmoidGrad",
184 "Sign",
185 "Sin",
186 "Sinh",
187 "Slice",
188 "Snapshot",
189 "Softplus",
190 "SoftplusGrad",
191 "Split",
192 "SplitV",
193 "StridedSlice",
194 "StridedSliceGrad",
195 "Switch",
196 "Tile",
197 "TruncateDiv",
198 "TruncateMod",
199 "ReverseV2",
200 "Round",
201 "Rsqrt",
202 "RsqrtGrad",
203 "Sqrt",
204 "SqrtGrad",
205 "Square",
206 "SquaredDifference",
207 "Squeeze",
208 "StopGradient",
209 "Sub",
210 "Sum",
211 "Tan",
212 "Tanh",
213 "TanhGrad",
214 "ZerosLike",
215 "Zeta"};
216 return ops_format_agnostic;
217 }
218
EndWith(const string & str,const string & ending)219 bool EndWith(const string& str, const string& ending) {
220 if (str.size() < ending.size()) return false;
221 if (str.substr(str.size() - ending.size(), ending.size()) == ending)
222 return true;
223 return false;
224 }
225
IsNodeByLayoutOptimizer(const string & node_name)226 bool IsNodeByLayoutOptimizer(const string& node_name) {
227 const string suffix = kSuffix;
228 return EndWith(node_name, suffix);
229 }
230
IsNodeType(const string & node_name,const string & type)231 bool IsNodeType(const string& node_name, const string& type) {
232 const string suffix = strings::StrCat(type, "-", kSuffix);
233 return EndWith(node_name, suffix);
234 }
235
// True iff the node is a generated NHWC->NCHW Transpose.
bool IsTransposeNHWCToNCHW(const string& node_name) {
  return IsNodeType(node_name, kTransposeNHWCToNCHW);
}
239
// True iff the node is a generated NCHW->NHWC Transpose.
bool IsTransposeNCHWToNHWC(const string& node_name) {
  return IsNodeType(node_name, kTransposeNCHWToNHWC);
}
243
// True iff the node is a generated NHWC->NCHW dimension-index remap node.
bool IsDimMapNHWCToNCHW(const string& node_name) {
  return IsNodeType(node_name, kDimMapNHWCToNCHW);
}
247
// True iff the node is a generated NCHW->NHWC dimension-index remap node.
bool IsDimMapNCHWToNHWC(const string& node_name) {
  return IsNodeType(node_name, kDimMapNCHWToNHWC);
}
251
// True iff the node is a generated NHWC->NCHW vector-permute node.
bool IsVecPermuteNHWCToNCHW(const string& node_name) {
  return IsNodeType(node_name, kVecPermuteNHWCToNCHW);
}
255
// True iff the node is a generated NCHW->NHWC vector-permute node.
bool IsVecPermuteNCHWToNHWC(const string& node_name) {
  return IsNodeType(node_name, kVecPermuteNCHWToNHWC);
}
259
IsConcat(const NodeDef & node)260 bool IsConcat(const NodeDef& node) {
261 const auto op = node.op();
262 return op == "Concat" || op == "ConcatV2";
263 }
264
IsConcatV1(const NodeDef & node)265 bool IsConcatV1(const NodeDef& node) {
266 const auto op = node.op();
267 return op == "Concat";
268 }
269
IsMaxPoolV2(const NodeDef & node)270 bool IsMaxPoolV2(const NodeDef& node) {
271 const auto& op = node.op();
272 return op == "MaxPoolV2";
273 }
274
IsMaxPoolGradV1(const NodeDef & node)275 bool IsMaxPoolGradV1(const NodeDef& node) {
276 const auto& op = node.op();
277 return op == "MaxPoolGrad";
278 }
279
IsMaxPoolGradV2(const NodeDef & node)280 bool IsMaxPoolGradV2(const NodeDef& node) {
281 const auto& op = node.op();
282 return op == "MaxPoolGradV2";
283 }
284
IsMaxPoolGradGradV1(const NodeDef & node)285 bool IsMaxPoolGradGradV1(const NodeDef& node) {
286 const auto& op = node.op();
287 return op == "MaxPoolGradGrad";
288 }
289
IsMaxPoolGradGradV2(const NodeDef & node)290 bool IsMaxPoolGradGradV2(const NodeDef& node) {
291 const auto& op = node.op();
292 return op == "MaxPoolGradGradV2";
293 }
294
IsUnaryGrad(const NodeDef & node)295 bool IsUnaryGrad(const NodeDef& node) {
296 bool is_unary_grad =
297 IsEluGrad(node) || IsInvGrad(node) || IsReciprocalGrad(node) ||
298 IsRelu6Grad(node) || IsReluGrad(node) || IsRsqrtGrad(node) ||
299 IsSeluGrad(node) || IsSigmoidGrad(node) || IsSoftplusGrad(node) ||
300 IsSoftsignGrad(node) || IsSqrtGrad(node) || IsTanhGrad(node);
301 return is_unary_grad;
302 }
303
IsComparisonOp(const NodeDef & node)304 bool IsComparisonOp(const NodeDef& node) {
305 bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
306 IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
307 IsLessEqual(node) || IsNotEqual(node);
308 return is_compare;
309 }
310
IsReduceOp(const NodeDef & node)311 bool IsReduceOp(const NodeDef& node) {
312 return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
313 IsMin(node) || IsAll(node) || IsAny(node);
314 }
315
IsBinaryOp(const NodeDef & node)316 bool IsBinaryOp(const NodeDef& node) {
317 bool is_binary =
318 IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
319 IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
320 IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
321 IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
322 IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
323 IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
324 return is_binary;
325 }
326
NonControlInputs(const NodeDef & node)327 std::vector<int> NonControlInputs(const NodeDef& node) {
328 std::vector<int> pos;
329 for (int i = 0; i < node.input_size(); i++) {
330 if (!IsControlInput(node.input(i))) {
331 pos.push_back(i);
332 }
333 }
334 return pos;
335 }
336
DataInputPosConcat(const NodeDef & node)337 std::vector<int> DataInputPosConcat(const NodeDef& node) {
338 int n = node.attr().at("N").i();
339 std::vector<int> input_pos;
340 int start = (IsConcatV1(node)) ? 1 : 0;
341 int end = start + n;
342 for (int i = start; i < end; i++) {
343 input_pos.push_back(i);
344 }
345 return input_pos;
346 }
347
// Returns the input positions that carry layout-sensitive data tensors,
// dispatching on the op type. Order of the checks matters: each op must hit
// its most specific case. Ops not listed default to input 0 when it is a
// data (non-control) input, else no positions.
std::vector<int> DataInputPos(const NodeDef& node) {
  // These ops carry their data tensor in input 1.
  if (IsSplit(node) || IsHistogramSummary(node)) {
    return {1};
  }
  // StridedSliceGrad's data tensor is input 4.
  if (IsStridedSliceGrad(node)) {
    return {4};
  }
  // Binary ops and unary-gradient ops have two data inputs.
  if (IsBinaryOp(node) || IsUnaryGrad(node)) {
    return {0, 1};
  }
  if (IsBetainc(node) || IsSelect(node)) {
    return {0, 1, 2};
  }
  // Variadic ops: every non-control input is data.
  if (IsShapeN(node) || IsIdentityN(node) || IsAddN(node) || IsMerge(node)) {
    return NonControlInputs(node);
  }
  // Concat/ConcatV2: the N concatenated tensors (axis input excluded).
  if (IsConcat(node)) {
    return DataInputPosConcat(node);
  }
  if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
    return {0};
  }
  return {};
}
372
IsHostMemory(const NodeDef & node,int output_port)373 bool IsHostMemory(const NodeDef& node, int output_port) {
374 DeviceNameUtils::ParsedName parsed_name;
375 if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
376 DeviceType device_type(parsed_name.type);
377 Status s = FindKernelDef(device_type, node, nullptr, nullptr);
378 if (s.ok()) {
379 tensorflow::MemoryTypeVector in_mtypes;
380 tensorflow::MemoryTypeVector out_mtypes;
381 s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
382 node, &in_mtypes, &out_mtypes);
383 if (s.ok()) {
384 if (out_mtypes[output_port] == HOST_MEMORY) {
385 return true;
386 }
387 }
388 } else {
389 return true;
390 }
391 }
392 return false;
393 }
394
395 class GraphProcessor {
396 public:
GraphProcessor(const GraphProperties & graph_properties,const VirtualPlacer & virtual_placer,const std::unordered_set<string> & nodes_to_preserve,GraphDef * graph,NodeMap * node_map)397 GraphProcessor(const GraphProperties& graph_properties,
398 const VirtualPlacer& virtual_placer,
399 const std::unordered_set<string>& nodes_to_preserve,
400 GraphDef* graph, NodeMap* node_map)
401 : graph_properties_(graph_properties),
402 virtual_placer_(virtual_placer),
403 nodes_to_preserve_(nodes_to_preserve),
404 graph_(graph),
405 node_map_(node_map) {}
406
407 protected:
AddNodePermConst(const string & name,const string & device,const std::vector<int> & permutation)408 NodeDef* AddNodePermConst(const string& name, const string& device,
409 const std::vector<int>& permutation) {
410 NodeDef* node = graph_->add_node();
411 node_map_->AddNode(name, node);
412 node->set_name(name);
413 node->set_op("Const");
414 AttrValue attr_data_type;
415 attr_data_type.set_type(DT_INT32);
416 node->mutable_attr()->insert({"dtype", attr_data_type});
417 AttrValue attr_tensor;
418 Tensor tensor(DT_INT32, TensorShape({4}));
419 for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
420 tensor.flat<int>()(i) = permutation[i];
421 }
422 tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
423 node->mutable_attr()->insert({"value", attr_tensor});
424 string device_name;
425 if (device.empty()) {
426 device_name = virtual_placer_.get_canonical_device_name(*node);
427 } else {
428 device_name = device;
429 }
430 node->set_device(device_name);
431 return node;
432 }
433
AddNodeConstScalar(const string & name,const string & device,DataType dtype,int value)434 NodeDef* AddNodeConstScalar(const string& name, const string& device,
435 DataType dtype, int value) {
436 NodeDef* node = graph_->add_node();
437 node_map_->AddNode(name, node);
438 node->set_name(name);
439 node->set_op("Const");
440 AttrValue attr_data_type;
441 attr_data_type.set_type(dtype);
442 node->mutable_attr()->insert({"dtype", attr_data_type});
443 AttrValue attr_tensor;
444 Tensor tensor(dtype, TensorShape({}));
445 tensor.scalar<int>()() = value;
446 tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
447 node->mutable_attr()->insert({"value", attr_tensor});
448 string device_name;
449 if (device.empty()) {
450 device_name = virtual_placer_.get_canonical_device_name(*node);
451 } else {
452 device_name = device;
453 }
454 node->set_device(device_name);
455 return node;
456 }
457
LayoutOptimizerNode(const string & base_name)458 string LayoutOptimizerNode(const string& base_name) {
459 return strings::StrCat(base_name, "-", kSuffix);
460 }
461
462 const GraphProperties& graph_properties_;
463 const VirtualPlacer& virtual_placer_;
464 const std::unordered_set<string>& nodes_to_preserve_;
465 GraphDef* graph_;
466 NodeMap* node_map_;
467 };
468
// Bundles everything a NodeProcessor needs to convert one node: the graph,
// the node itself, lookup/placement helpers, the preserve set, and whether
// the node lives inside a loop frame. All pointers/references are non-owning.
struct OptimizeContext {
  OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
                  const GraphProperties& graph_properties,
                  const VirtualPlacer& virtual_placer,
                  const std::unordered_set<string>& nodes_to_preserve,
                  bool is_in_frame)
      : graph(graph),
        node(node),
        node_map(node_map),
        graph_properties(graph_properties),
        virtual_placer(virtual_placer),
        nodes_to_preserve(nodes_to_preserve),
        is_in_frame(is_in_frame) {}
  GraphDef* graph;
  NodeDef* node;
  NodeMap* node_map;
  const GraphProperties& graph_properties;
  const VirtualPlacer& virtual_placer;
  const std::unordered_set<string>& nodes_to_preserve;
  bool is_in_frame;  // True when `node` is inside a loop frame.
};
490
491 class NodeProcessor : public GraphProcessor {
492 public:
  // Wraps the shared GraphProcessor state around the single node being
  // converted. `is_in_frame` (from the context) controls whether permutation
  // constants are created per node with control dependencies (see the
  // GetOrAddNodePerm* helpers below).
  explicit NodeProcessor(const OptimizeContext& opt_cxt)
      : GraphProcessor(opt_cxt.graph_properties, opt_cxt.virtual_placer,
                       opt_cxt.nodes_to_preserve, opt_cxt.graph,
                       opt_cxt.node_map),
        node_(opt_cxt.node),
        is_in_frame_(opt_cxt.is_in_frame) {}
  virtual ~NodeProcessor() {}
  // Entry point: if the node qualifies (see ShouldProcess), converts it from
  // NHWC to NCHW by rewriting its layout-dependent attributes and inserting
  // conversion nodes around its inputs and outputs. Subclasses hook extra
  // work via CustomizedProcessing.
  virtual Status ConvertNode() {
    if (ShouldProcess()) {
      UpdateAttrDataFormat();
      UpdateAttrKSize();
      UpdateAttrStrides();
      UpdateAttrDilations();
      UpdateAttrExplicitPaddings();
      UpdateAttrShape();
      TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
      TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
      TF_RETURN_IF_ERROR(CustomizedProcessing());
    }
    return Status::OK();
  }
514
515 protected:
IsPortDimsN(const NodeDef & node,int port,int n) const516 bool IsPortDimsN(const NodeDef& node, int port, int n) const {
517 if (node.attr().find("_output_shapes") != node.attr().end()) {
518 if (node.attr().at("_output_shapes").list().shape_size() > port) {
519 auto shape = node.attr().at("_output_shapes").list().shape(port);
520 if (shape.unknown_rank()) {
521 return false;
522 }
523 if (shape.dim_size() == n) {
524 return true;
525 }
526 }
527 }
528 return false;
529 }
530
IsPortZeroDimsN(const NodeDef & node,int n) const531 bool IsPortZeroDimsN(const NodeDef& node, int n) const {
532 return IsPortDimsN(node, 0, n);
533 }
534
IsPortZeroDimsFour(const NodeDef & node) const535 bool IsPortZeroDimsFour(const NodeDef& node) const {
536 return NodeProcessor::IsPortZeroDimsN(node, 4) ||
537 IsTransposeNCHWToNHWC(node.name());
538 }
539
IsPortDimsFour(const NodeDef & node,int port) const540 bool IsPortDimsFour(const NodeDef& node, int port) const {
541 return NodeProcessor::IsPortDimsN(node, port, 4) ||
542 IsTransposeNCHWToNHWC(node.name());
543 }
544
IsNHWC() const545 bool IsNHWC() const {
546 if (node_->attr().find("data_format") != node_->attr().end()) {
547 if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
548 return true;
549 }
550 }
551 return false;
552 }
553
HasOutputs() const554 bool HasOutputs() const {
555 auto outputs = node_map_->GetOutputs(node_->name());
556 return !outputs.empty();
557 }
558
HasAttribute(const NodeDef & node,const string & attr) const559 Status HasAttribute(const NodeDef& node, const string& attr) const {
560 if (node.attr().find(attr) == node.attr().end()) {
561 return Status(error::INVALID_ARGUMENT,
562 strings::StrCat("Missing attribute ", attr));
563 }
564 return Status::OK();
565 }
566
MustPreserve() const567 bool MustPreserve() const {
568 return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
569 }
570
IsOnGPU() const571 bool IsOnGPU() const {
572 string device_name;
573 if (node_->device().empty()) {
574 device_name = virtual_placer_.get_canonical_device_name(*node_);
575 } else {
576 device_name = node_->device();
577 }
578 string device;
579 string not_used;
580 if (DeviceNameUtils::SplitDeviceName(device_name, ¬_used, &device) &&
581 str_util::StrContains(str_util::Lowercase(device),
582 str_util::Lowercase(DEVICE_GPU))) {
583 return true;
584 }
585 return false;
586 }
587
ShouldProcess() const588 virtual bool ShouldProcess() const {
589 return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
590 HasOutputs() && IsOnGPU();
591 }
592
  // Rewrites the cached "_output_shapes" attribute from NHWC to NCHW order
  // ([N,H,W,C] -> [N,C,H,W]) for every output in GetOutputPos(). Shapes that
  // are not 4-D are left untouched.
  virtual void UpdateAttrShape() {
    if (node_->attr().find("_output_shapes") != node_->attr().end()) {
      for (const auto& pos : GetOutputPos()) {
        auto shape = node_->mutable_attr()
                         ->at("_output_shapes")
                         .mutable_list()
                         ->mutable_shape(pos);
        if (shape->dim_size() == 4) {
          // Save H, W, C before overwriting; C moves to position 1 and
          // H/W shift right.
          int64 h = shape->dim(1).size();
          int64 w = shape->dim(2).size();
          int64 c = shape->dim(3).size();
          shape->mutable_dim(1)->set_size(c);
          shape->mutable_dim(2)->set_size(h);
          shape->mutable_dim(3)->set_size(w);
        }
      }
    }
  }
611
  // Replaces input `input_index` with a renamed copy of its producer node
  // and rewrites the copy's "value" attribute via UpdateAttrValue. `permute`
  // selects vector-permutation semantics vs. dimension-index remapping.
  Status UpdateAttrValueOfInput(int input_index, bool permute) {
    auto input_node = node_map_->GetNode(node_->input(input_index));
    // We created a copy of the node, so that we don't modify the original node,
    // which might be used elsewhere. Note that this copy also copies the
    // control dependency input in the case this node is inside a loop,
    // to ensure added_node is in the same frame with node_.
    NodeDef* added_node = graph_->add_node();
    *added_node = *input_node;
    string base_name = strings::StrCat(node_->name(), "-", input_index);
    string node_name = LayoutOptimizerNode(base_name);
    added_node->set_name(node_name);
    *node_->mutable_input(input_index) = node_name;
    node_map_->AddNode(node_name, added_node);
    node_map_->AddOutput(node_name, node_->name());
    return UpdateAttrValue(added_node, permute);
  }
628
GetInputPos() const629 virtual std::vector<int> GetInputPos() const { return {0}; }
630
GetOutputPos() const631 virtual std::set<int> GetOutputPos() const {
632 // For most nodes, no need to process control nodes or nodes that use an
633 // output other than the first output: only the first output is of
634 // 4D NCHW/NHWC format and thus relevant here.
635 std::set<int> output_pos = {0};
636 return output_pos;
637 }
638
  // Inserts an NHWC->NCHW Transpose node in front of every input listed by
  // GetInputPos(), rewiring both the node's input strings and the node map.
  virtual Status AddLayoutTransposeToInputs() {
    std::vector<int> input_pos = GetInputPos();
    for (const auto& pos : input_pos) {
      // Generated name encodes the consumer and the input position, e.g.
      // "<node>-<pos>-TransposeNHWCToNCHW-LayoutOptimizer".
      string node_name = LayoutOptimizerNode(
          strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
      DataType dtype =
          graph_properties_.GetInputProperties(node_->name())[pos].dtype();
      auto input_node = node_map_->GetNode(node_->input(pos));
      // The producer's inferred shape is required to annotate the transpose.
      TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
      string const_name = GetOrAddNodePermNHWCToNCHW(pos);
      int output_pos;
      ParseNodeName(node_->input(pos), &output_pos);
      AddNodeTranspose(
          node_name, node_->input(pos), const_name, dtype,
          input_node->attr().at("_output_shapes").list().shape(output_pos),
          true);
      // Re-point producer -> transpose -> this node in the node map.
      node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
                              node_name);
      node_map_->AddOutput(node_name, node_->name());
      *node_->mutable_input(pos) = node_name;
    }
    return Status::OK();
  }
662
  // Inserts an NCHW->NHWC conversion node (`op` is "Transpose" or
  // "DataFormatVecPermute") between this node and every consumer edge that
  // reads one of the ports in GetOutputPos(). Consumer edges reading other
  // ports are left untouched.
  Status AddTransformToOutputs(const string& op) {
    auto outputs = node_map_->GetOutputs(node_->name());
    string const_name = GetOrAddNodePermNCHWToNHWC();
    int output_count = 0;
    for (const auto& output : outputs) {
      // Track how many of this consumer's edges point at us, and how many
      // of those were rerouted through a new conversion node.
      int connections = 0;
      int connections_removed = 0;
      for (int i = 0; i < output->input_size(); i++) {
        auto& input = *output->mutable_input(i);
        int input_port;
        string input_name = ParseNodeName(input, &input_port);
        auto output_pos = GetOutputPos();
        if (input_name == node_->name()) {
          connections++;
          if (output_pos.find(input_port) != output_pos.end()) {
            connections_removed++;
            // Name encodes producer, consumer index, and input slot so each
            // rerouted edge gets a unique conversion node.
            string added_node_base_name =
                strings::StrCat(node_->name(), "-", output_count, "-", i);
            string added_node_name;
            DataType dtype =
                graph_properties_.GetOutputProperties(node_->name())[input_port]
                    .dtype();
            if (op == "Transpose") {
              added_node_name = LayoutOptimizerNode(strings::StrCat(
                  added_node_base_name, "-", kTransposeNCHWToNHWC));
              TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
              AddNodeTranspose(
                  added_node_name, input, const_name, dtype,
                  node_->attr().at("_output_shapes").list().shape(input_port),
                  false);
            } else if (op == "DataFormatVecPermute") {
              added_node_name = LayoutOptimizerNode(strings::StrCat(
                  added_node_base_name, "-", kVecPermuteNCHWToNHWC));
              AddNodeDataFormatOp(added_node_name, input, op, dtype, false);
            } else {
              return errors::InvalidArgument("Unsupported op type: ", op);
            }
            input = added_node_name;
            node_map_->AddOutput(node_->name(), added_node_name);
            node_map_->AddOutput(added_node_name, output->name());
          }
        }
      }
      // If every edge to this consumer was rerouted, drop the direct
      // producer->consumer record from the node map.
      if (connections == connections_removed) {
        node_map_->RemoveOutput(node_->name(), output->name());
      }
      output_count++;
    }
    return Status::OK();
  }
713
  // Default output handling: restore NHWC with Transpose nodes.
  virtual Status AddLayoutTransposeToOutputs() {
    return AddTransformToOutputs("Transpose");
  }
717
  // Hook for op-specific extra work after the generic conversion; no-op here.
  virtual Status CustomizedProcessing() { return Status::OK(); }
719
UpdateOrTransformParamInput(int param_index,const string & op,DataType dtype)720 Status UpdateOrTransformParamInput(int param_index, const string& op,
721 DataType dtype) {
722 auto param_node = node_map_->GetNode(node_->input(param_index));
723 bool permute = (op == "DataFormatVecPermute") ? true : false;
724 if (IsConstant(*param_node)) {
725 TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(param_index, permute));
726 } else {
727 AddDataFormatTranformToParamInput(op, param_index, dtype);
728 }
729 return Status::OK();
730 }
731
  NodeDef* node_;     // The node being converted; owned by the graph.
  bool is_in_frame_;  // Whether node_ is inside a loop frame.
734
735 private:
UpdateAttrKSize()736 void UpdateAttrKSize() {
737 if (node_->attr().find("ksize") != node_->attr().end()) {
738 auto list = node_->mutable_attr()->at("ksize").mutable_list();
739 UpdateTuple(list);
740 }
741 }
742
UpdateAttrStrides()743 void UpdateAttrStrides() {
744 if (node_->attr().find("strides") != node_->attr().end()) {
745 auto list = node_->mutable_attr()->at("strides").mutable_list();
746 UpdateTuple(list);
747 }
748 }
749
UpdateAttrDilations()750 void UpdateAttrDilations() {
751 if (node_->attr().find("dilations") != node_->attr().end()) {
752 auto list = node_->mutable_attr()->at("dilations").mutable_list();
753 UpdateTuple(list);
754 }
755 }
756
  // Reorders an 8-entry "explicit_paddings" attribute (before/after pairs
  // for each of 4 dimensions) from NHWC to NCHW: the H and W pads move to
  // slots 4..7 and the C slots (2,3) are zeroed. Reads happen before any
  // write, so the in-place shuffle is safe.
  void UpdateAttrExplicitPaddings() {
    if (node_->attr().find("explicit_paddings") != node_->attr().end()) {
      auto list = node_->mutable_attr()->at("explicit_paddings").mutable_list();
      int size = list->i_size();
      if (size == 8) {
        // NHWC layout: entries 2,3 are H pads and 4,5 are W pads.
        int64 height_before = list->i(2);
        int64 height_after = list->i(3);
        int64 width_before = list->i(4);
        int64 width_after = list->i(5);
        // NCHW layout: C pads (slots 2,3) become zero; H/W shift right.
        list->set_i(2, 0);
        list->set_i(3, 0);
        list->set_i(4, height_before);
        list->set_i(5, height_after);
        list->set_i(6, width_before);
        list->set_i(7, width_after);
      } else if (size != 0) {
        LOG(ERROR) << "Cannot handle explicit_paddings attribute of size "
                   << size;
      }
    }
  }
778
UpdateAttrDataFormat()779 void UpdateAttrDataFormat() {
780 if (node_->attr().find("data_format") != node_->attr().end()) {
781 if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
782 string* data_format =
783 node_->mutable_attr()->at("data_format").mutable_s();
784 *data_format = "NCHW";
785 }
786 }
787 }
788
  // Rewrites the "value" tensor of a Const node from NHWC to NCHW
  // conventions. With `permute` true, entries are permuted: a length-4
  // vector [N,H,W,C] becomes [N,C,H,W], and a 4x2 matrix (e.g. paddings)
  // has its rows permuted the same way. With `permute` false, each entry is
  // treated as a dimension index and remapped (3 -> 1, 1 -> 2, 2 -> 3, 0
  // unchanged; negative indices are first normalized by adding 4).
  Status UpdateAttrValue(NodeDef* node, bool permute) {
    TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
    Tensor tensor;
    auto success =
        tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
    if (!success) {
      LOG(ERROR) << "Failed to parse TensorProto.";
    }

    if (permute) {
      if (tensor.dims() == 1) {
        if (tensor.flat<int>().size() == 4) {
          // Rotate entries 1..3 right: [N,H,W,C] -> [N,C,H,W].
          int c = tensor.flat<int>()(3);
          tensor.flat<int>()(3) = tensor.flat<int>()(2);
          tensor.flat<int>()(2) = tensor.flat<int>()(1);
          tensor.flat<int>()(1) = c;
        } else {
          return Status(error::INVALID_ARGUMENT,
                        strings::StrCat("Unsupported tensor size: ",
                                        tensor.flat<int>().size()));
        }
      } else if (tensor.dims() == 2) {
        // Permute rows 1..3 of each column the same way.
        for (int i = 0; i < 2; i++) {
          int c = tensor.matrix<int>()(3, i);
          tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
          tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
          tensor.matrix<int>()(1, i) = c;
        }
      } else {
        return Status(
            error::INVALID_ARGUMENT,
            strings::StrCat("Unsupported dimension size: ", tensor.dims()));
      }
    } else {
      // Dimension-index remap: NHWC axis -> NCHW axis.
      for (int i = 0; i < tensor.flat<int>().size(); i++) {
        int value = tensor.flat<int>()(i);
        value = (value >= 0) ? value : value + 4;
        if (value == 1 || value == 2) {
          value = value + 1;
        } else if (value == 3) {
          value = 1;
        }
        tensor.flat<int>()(i) = value;
      }
    }

    if (tensor.dims() == 0) {
      tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
    } else {
      tensor.AsProtoTensorContent(
          node->mutable_attr()->at({"value"}).mutable_tensor());
    }
    return Status::OK();
  }
843
  // Builds a Transpose node `node_name` reading `input_name` with the
  // permutation Const `const_name`. When the input shape is known, also
  // annotates "_output_shapes" with the permuted shape ([N,H,W,C]->[N,C,H,W]
  // if NHWCToNCHW, the inverse otherwise). Placed on this node's device.
  NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
                            const string& const_name, DataType data_type,
                            const TensorShapeProto& input_shape,
                            bool NHWCToNCHW) {
    NodeDef* node = graph_->add_node();
    node_map_->AddNode(node_name, node);
    node->set_name(node_name);
    *node->add_input() = input_name;
    *node->add_input() = const_name;
    node->set_op("Transpose");
    node->set_device(node_->device());
    AttrValue attr_data_type;
    attr_data_type.set_type(data_type);
    node->mutable_attr()->insert({"T", attr_data_type});
    AttrValue attr_data_type_perm;
    attr_data_type_perm.set_type(DT_INT32);
    node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
    if (!input_shape.unknown_rank()) {
      AttrValue attr_output_shape;
      auto output_shape = attr_output_shape.mutable_list()->add_shape();
      if (NHWCToNCHW) {
        output_shape->add_dim()->set_size(input_shape.dim(0).size());
        output_shape->add_dim()->set_size(input_shape.dim(3).size());
        output_shape->add_dim()->set_size(input_shape.dim(1).size());
        output_shape->add_dim()->set_size(input_shape.dim(2).size());
      } else {
        output_shape->add_dim()->set_size(input_shape.dim(0).size());
        output_shape->add_dim()->set_size(input_shape.dim(2).size());
        output_shape->add_dim()->set_size(input_shape.dim(3).size());
        output_shape->add_dim()->set_size(input_shape.dim(1).size());
      }
      node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
    }
    return node;
  }
879
AddNodePermNHWCToNCHW(const string & base_name,const string & depended_node,const string & device)880 NodeDef* AddNodePermNHWCToNCHW(const string& base_name,
881 const string& depended_node,
882 const string& device) {
883 string name =
884 LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNHWCToNCHW));
885 auto const_node = AddNodePermConst(name, device, {0, 3, 1, 2});
886 // This is to ensure the transpose node and the const node are in the
887 // same frame.
888 *const_node->add_input() = AsControlDependency(depended_node);
889 return const_node;
890 }
891
AddNodePermNCHWToNHWC(const string & base_name,const string & depended_node,const string & device)892 NodeDef* AddNodePermNCHWToNHWC(const string& base_name,
893 const string& depended_node,
894 const string& device) {
895 auto const_node = AddNodePermConst(
896 LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNCHWToNHWC)),
897 device, {0, 2, 3, 1});
898 // This is to ensure the transpose node and the const node are in the same
899 // frame.
900 *const_node->add_input() = AsControlDependency(depended_node);
901 return const_node;
902 }
903
  // Returns the name of an NHWC->NCHW permutation constant for input `pos`.
  // Inside a loop frame a dedicated const is created (anchored via a control
  // dependency so it stays in the frame); otherwise the shared global
  // constant's name is returned.
  string GetOrAddNodePermNHWCToNCHW(int pos) {
    string const_name;
    if (is_in_frame_) {
      string base_name = strings::StrCat(node_->name(), "-", pos);
      string input = NodeName(node_->input(pos));
      string depended_node;
      // Anchor on the original producer, not on a transpose we inserted.
      if (!IsTransposeNCHWToNHWC(input)) {
        depended_node = input;
      } else {
        auto input_node = node_map_->GetNode(input);
        depended_node = NodeName(input_node->input(0));
      }
      auto const_node =
          AddNodePermNHWCToNCHW(base_name, depended_node, node_->device());
      const_name = const_node->name();
    } else {
      const_name = LayoutOptimizerNode(kPermNHWCToNCHW);
    }
    return const_name;
  }
924
GetOrAddNodePermNCHWToNHWC()925 string GetOrAddNodePermNCHWToNHWC() {
926 string const_name;
927 if (is_in_frame_) {
928 auto const_node =
929 AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
930 const_name = const_node->name();
931 } else {
932 const_name = LayoutOptimizerNode(kPermNCHWToNHWC);
933 }
934 return const_name;
935 }
936
UpdateTuple(AttrValue_ListValue * list)937 void UpdateTuple(AttrValue_ListValue* list) {
938 int64 h = list->i(1);
939 int64 w = list->i(2);
940 int64 c = list->i(3);
941 list->set_i(1, c);
942 list->set_i(2, h);
943 list->set_i(3, w);
944 }
945
IsInputOnHost(const string & input_name) const946 bool IsInputOnHost(const string& input_name) const {
947 string device = node_->device();
948 DeviceNameUtils::ParsedName parsed_name;
949 if (DeviceNameUtils::ParseFullName(device, &parsed_name)) {
950 if (parsed_name.type != "CPU") {
951 NodeDef* input = node_map_->GetNode(input_name);
952 int port;
953 ParseNodeName(input_name, &port);
954 if (IsHostMemory(*input, port)) {
955 return true;
956 }
957 }
958 }
959 return false;
960 }
961
  // Adds a data-format conversion node (`op` such as "DataFormatVecPermute"
  // or "DataFormatDimMap") named `name`, reading from `input_name`, on this
  // node's device. `dtype` populates the "T" attribute; `nhwc_to_nchw`
  // selects the conversion direction via src_format/dst_format. The new node
  // is registered in the node map and returned.
  NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
                               const string& op, DataType dtype,
                               bool nhwc_to_nchw) {
    NodeDef* added_node = graph_->add_node();
    added_node->set_name(name);
    added_node->set_op(op);
    node_map_->AddNode(added_node->name(), added_node);
    added_node->set_device(node_->device());
    // The inputs of a DataFormat op could be in host memory for ops such as
    // Reshape. In such cases, run the kernel on the host too.
    if (IsInputOnHost(input_name)) {
      AttrValue attr_kernel;
      attr_kernel.set_s("host");
      added_node->mutable_attr()->insert({"_kernel", attr_kernel});
    }
    AttrValue attr_data_type;
    attr_data_type.set_type(dtype);
    added_node->mutable_attr()->insert({"T", attr_data_type});
    // The same AttrValue is reused for both format attributes; insert copies
    // it, so mutating it afterwards for dst_format is safe.
    string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
    string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
    AttrValue attr_format;
    attr_format.set_s(src_format);
    added_node->mutable_attr()->insert({"src_format", attr_format});
    attr_format.set_s(dst_format);
    added_node->mutable_attr()->insert({"dst_format", attr_format});
    *added_node->add_input() = input_name;
    return added_node;
  }
990
  // Rewires input `input_pos` of this node through a freshly created
  // data-format op (`op` is "DataFormatVecPermute" or "DataFormatDimMap")
  // that maps the NHWC-ordered parameter to NCHW, updating the node map's
  // producer/consumer edges accordingly.
  // NOTE(review): "Tranform" in the name looks like a typo for "Transform";
  // kept as-is because callers elsewhere use this spelling.
  void AddDataFormatTranformToParamInput(const string& op, int input_pos,
                                         DataType dtype) {
    string suffix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
                                                   : kDimMapNHWCToNCHW;
    string name = LayoutOptimizerNode(
        strings::StrCat(node_->name(), "-", input_pos, "-", suffix));
    auto added_node =
        AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
    *node_->mutable_input(input_pos) = added_node->name();
    // Re-point the original producer's edge so it feeds the new node, which
    // in turn feeds this node.
    node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
                            added_node->name());
    node_map_->AddOutput(added_node->name(), node_->name());
  }
1004 };
1005
1006 class AvgPoolGradProcessor : public NodeProcessor {
1007 public:
AvgPoolGradProcessor(const OptimizeContext & opt_cxt)1008 explicit AvgPoolGradProcessor(const OptimizeContext& opt_cxt)
1009 : NodeProcessor(opt_cxt) {}
1010
1011 protected:
GetInputPos() const1012 std::vector<int> GetInputPos() const override { return {1}; }
CustomizedProcessing()1013 Status CustomizedProcessing() override {
1014 return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
1015 }
1016 };
1017
1018 class BiasAddGradProcessor : public NodeProcessor {
1019 public:
BiasAddGradProcessor(const OptimizeContext & opt_cxt)1020 explicit BiasAddGradProcessor(const OptimizeContext& opt_cxt)
1021 : NodeProcessor(opt_cxt) {}
1022
1023 protected:
ShouldProcess() const1024 bool ShouldProcess() const override {
1025 if (MustPreserve()) {
1026 return false;
1027 }
1028 if (!IsOnGPU()) {
1029 return false;
1030 }
1031 auto input = node_map_->GetNode(node_->input(0));
1032 if (input) {
1033 int port;
1034 ParseNodeName(node_->input(0), &port);
1035 if (IsNHWC() && IsPortDimsFour(*input, port)) {
1036 return true;
1037 }
1038 }
1039 return false;
1040 }
1041
AddLayoutTransposeToOutputs()1042 Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1043 };
1044
// Handles Conv2D. Skips conversion when the convolution would be lowered to
// a plain matrix multiply (GEMM), unless `no_gemm` forces processing anyway.
class Conv2DProcessor : public NodeProcessor {
 public:
  Conv2DProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
      : NodeProcessor(opt_cxt), no_gemm_(no_gemm) {}

 protected:
  bool ShouldProcess() const override {
    return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
           HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU();
  }

  // Returns the inferred shape of `input_name` from its producer's
  // "_output_shapes" attribute, or an empty shape when unavailable.
  TensorShapeProto GetShape(const string& input_name) const {
    string node_name;
    int output_pos;
    node_name = ParseNodeName(input_name, &output_pos);
    NodeDef* node = node_map_->GetNode(node_name);
    if (node->attr().find("_output_shapes") != node->attr().end()) {
      return node->attr().at("_output_shapes").list().shape(output_pos);
    }
    TensorShapeProto shape;
    return shape;
  }

  // True when both spatial strides (entries 1 and 2 in NHWC order) are 1.
  bool IsStrideOne() const {
    if (node_->attr().find("strides") != node_->attr().end()) {
      auto list = node_->attr().at("strides").list();
      return list.i(1) == 1 && list.i(2) == 1;
    }
    return false;
  }

  bool IsValidPadding() const {
    if (node_->attr().find("padding") != node_->attr().end()) {
      auto padding = node_->attr().at("padding").s();
      return padding == "VALID";
    }
    return false;
  }

  // The logic inside this function is based on the internal implementation of
  // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
  // needs to be updated accordingly if the internal implementation changes.
  // GEMM is used for 1x1 stride-1 filters, or when the filter covers the
  // whole spatial extent of the input under VALID padding.
  bool IsGemmUsed(const TensorShapeProto& filter_shape,
                  const TensorShapeProto& input_shape) const {
    if (filter_shape.dim_size() == 4) {
      if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 &&
          IsStrideOne()) {
        return true;
      }
    }
    if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) {
      if (input_shape.dim(1).size() == filter_shape.dim(0).size() &&
          input_shape.dim(2).size() == filter_shape.dim(1).size() &&
          IsValidPadding()) {
        return true;
      }
    }
    return false;
  }

  // For Conv2D itself the filter is input 1 and the data is input 0;
  // backprop subclasses override this to source the shapes differently.
  virtual bool IsGemmUsed() const {
    auto filter_shape = GetShape(node_->input(1));
    auto input_shape = GetShape(node_->input(0));
    return IsGemmUsed(filter_shape, input_shape);
  }

  bool no_gemm_;
};
1113
1114 class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
1115 public:
Conv2DBackpropFilterProcessor(const OptimizeContext & opt_cxt,bool no_gemm)1116 Conv2DBackpropFilterProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
1117 : Conv2DProcessor(opt_cxt, no_gemm) {}
1118
1119 protected:
IsGemmUsed() const1120 bool IsGemmUsed() const override {
1121 auto filter_shape = GetShape(node_->name());
1122 auto input_shape = GetShape(node_->input(0));
1123 return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
1124 }
1125
GetInputPos() const1126 std::vector<int> GetInputPos() const override { return {0, 2}; }
1127
AddLayoutTransposeToOutputs()1128 Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1129 // No need to update output shape, as it is always of shape
1130 // [filter_height, filter_width, in_channels, out_channels], regardless of
1131 // whether NCHW or NHWC is used.
UpdateAttrShape()1132 void UpdateAttrShape() override {}
1133 };
1134
1135 class Conv2DBackpropInputProcessor : public Conv2DProcessor {
1136 public:
Conv2DBackpropInputProcessor(const OptimizeContext & opt_cxt,bool no_gemm)1137 Conv2DBackpropInputProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
1138 : Conv2DProcessor(opt_cxt, no_gemm) {}
1139
1140 protected:
IsGemmUsed() const1141 bool IsGemmUsed() const override {
1142 auto filter_shape = GetShape(node_->input(1));
1143 auto input_shape = GetShape(node_->name());
1144 return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
1145 }
1146
GetInputPos() const1147 std::vector<int> GetInputPos() const override { return {2}; }
1148
CustomizedProcessing()1149 Status CustomizedProcessing() override {
1150 return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
1151 }
1152 };
1153
1154 class FusedBatchNormGradProcessor : public NodeProcessor {
1155 public:
FusedBatchNormGradProcessor(const OptimizeContext & opt_cxt)1156 explicit FusedBatchNormGradProcessor(const OptimizeContext& opt_cxt)
1157 : NodeProcessor(opt_cxt) {}
1158
1159 protected:
ShouldProcess() const1160 bool ShouldProcess() const override {
1161 return NodeProcessor::ShouldProcess() && IsTraining();
1162 }
1163
GetInputPos() const1164 std::vector<int> GetInputPos() const override { return {0, 1}; }
1165
1166 private:
IsTraining() const1167 bool IsTraining() const {
1168 if (node_->attr().find("is_training") != node_->attr().end()) {
1169 if (node_->attr().at("is_training").b()) {
1170 return true;
1171 }
1172 }
1173 return false;
1174 }
1175 };
1176
1177 class MaxPoolGradProcessor : public NodeProcessor {
1178 public:
MaxPoolGradProcessor(const OptimizeContext & opt_cxt)1179 explicit MaxPoolGradProcessor(const OptimizeContext& opt_cxt)
1180 : NodeProcessor(opt_cxt) {}
1181
1182 protected:
GetInputPos() const1183 std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
1184 };
1185
1186 class MaxPoolGradV2Processor : public MaxPoolGradProcessor {
1187 public:
MaxPoolGradV2Processor(const OptimizeContext & opt_cxt)1188 explicit MaxPoolGradV2Processor(const OptimizeContext& opt_cxt)
1189 : MaxPoolGradProcessor(opt_cxt) {}
1190
1191 protected:
CustomizedProcessing()1192 Status CustomizedProcessing() override {
1193 for (int i = 3; i <= 4; i++) {
1194 TF_RETURN_IF_ERROR(
1195 UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
1196 }
1197 return Status::OK();
1198 }
1199 };
1200
1201 class MaxPoolV2Processor : public NodeProcessor {
1202 public:
MaxPoolV2Processor(const OptimizeContext & opt_cxt)1203 explicit MaxPoolV2Processor(const OptimizeContext& opt_cxt)
1204 : NodeProcessor(opt_cxt) {}
1205
1206 protected:
ShouldProcess() const1207 bool ShouldProcess() const override {
1208 // We check data_input's shape instead, because the shape inference of
1209 // MaxPoolV2 is not able to infer the shape when ksize or strides is not
1210 // constant.
1211 auto data_input = node_map_->GetNode(node_->input(0));
1212 int port;
1213 ParseNodeName(node_->input(0), &port);
1214 return !MustPreserve() && IsNHWC() && IsPortDimsFour(*data_input, port) &&
1215 HasOutputs() && IsOnGPU();
1216 }
1217
CustomizedProcessing()1218 Status CustomizedProcessing() override {
1219 for (int i = 1; i <= 2; i++) {
1220 TF_RETURN_IF_ERROR(
1221 UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
1222 }
1223 return Status::OK();
1224 }
1225 };
1226
// Base processor for format-agnostic ops: ops that can be converted to NCHW
// whenever at least one data path upstream has already been converted.
class AgnosticNodeProcessor : public NodeProcessor {
 public:
  explicit AgnosticNodeProcessor(const OptimizeContext& opt_cxt)
      : NodeProcessor(opt_cxt) {}

 protected:
  bool ShouldProcess() const override {
    return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
           IsNodeAfterNCHWToNHWC() && IsOnGPU();
  }

  // BFS backwards through data inputs, looking for a node inserted by this
  // optimizer that converts NCHW back to NHWC (transpose, dim-map, or
  // vec-permute). Traversal only continues through format-agnostic ops.
  bool IsNodeAfterNCHWToNHWC(const NodeDef& node) const {
    std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
    std::deque<NodeDef*> queue;
    auto data_node_pos = DataInputPos(node);
    std::unordered_set<string> visited;
    for (const auto& pos : data_node_pos) {
      auto input_node = node_map_->GetNode(node.input(pos));
      queue.push_back(input_node);
      visited.insert(input_node->name());
    }
    // The code will exit this while loop in one iteration in most cases, as the
    // graph is already topologically sorted.
    while (!queue.empty()) {
      NodeDef* current_node = queue.front();
      queue.pop_front();
      if (IsTransposeNCHWToNHWC(current_node->name()) ||
          IsDimMapNCHWToNHWC(current_node->name()) ||
          IsVecPermuteNCHWToNHWC(current_node->name())) {
        return true;
      }
      // We only continue searching if the path is connected through
      // format-agnostic nodes.
      if (ops_format_agnostic.find(current_node->op()) !=
          ops_format_agnostic.end()) {
        auto current_node_pos = DataInputPos(*current_node);
        for (const auto& pos : current_node_pos) {
          auto input_node = node_map_->GetNode(current_node->input(pos));
          // The visited set guards against revisiting shared subgraphs.
          if (visited.find(input_node->name()) == visited.end()) {
            queue.push_back(input_node);
            visited.insert(input_node->name());
          }
        }
      }
    }
    return false;
  }

  bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
};
1277
1278 class AddNProcessor : public AgnosticNodeProcessor {
1279 public:
AddNProcessor(const OptimizeContext & opt_cxt)1280 explicit AddNProcessor(const OptimizeContext& opt_cxt)
1281 : AgnosticNodeProcessor(opt_cxt) {}
1282
1283 protected:
GetInputPos() const1284 std::vector<int> GetInputPos() const override {
1285 return NonControlInputs(*node_);
1286 }
1287 };
1288
// Handles binary elementwise ops where one operand is 4-D and the other may
// be a scalar, a channel-sized vector, or another 4-D tensor. A vector
// operand is reshaped to [1, C, 1, 1] so broadcasting stays correct after
// the NCHW conversion.
class BinaryOpProcessor : public AgnosticNodeProcessor {
 public:
  explicit BinaryOpProcessor(const OptimizeContext& opt_cxt)
      : AgnosticNodeProcessor(opt_cxt) {}

 protected:
  bool ShouldProcess() const override {
    return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
           IsNodeAfterNCHWToNHWC() &&
           (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) ||
            IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4) ||
            IsNDOperateWithMD(1, 4)) &&
           IsOnGPU();
  }

  // Only 4-D operands get a layout transpose; scalar/vector operands are
  // handled by CustomizedProcessing below.
  std::vector<int> GetInputPos() const override {
    std::vector<int> input_pos;
    auto input0 = node_map_->GetNode(node_->input(0));
    auto input1 = node_map_->GetNode(node_->input(1));
    int input0_port;
    ParseNodeName(node_->input(0), &input0_port);
    int input1_port;
    ParseNodeName(node_->input(1), &input1_port);
    if (IsPortDimsFour(*input0, input0_port)) {
      input_pos.push_back(0);
    }
    if (IsPortDimsFour(*input1, input1_port)) {
      input_pos.push_back(1);
    }
    return input_pos;
  }

  // True when input 0 has rank n and input 1 has rank m (4 means exactly
  // 4-D via IsPortDimsFour; other values go through IsPortDimsN).
  bool IsNDOperateWithMD(int n, int m) const {
    auto input0 = node_map_->GetNode(node_->input(0));
    auto input1 = node_map_->GetNode(node_->input(1));
    int input0_port;
    ParseNodeName(node_->input(0), &input0_port);
    int input1_port;
    ParseNodeName(node_->input(1), &input1_port);

    if (input0 && input1) {
      bool input0_is_n = (n == 4) ? IsPortDimsFour(*input0, input0_port)
                                  : IsPortDimsN(*input0, input0_port, n);
      bool input1_is_m = (m == 4) ? IsPortDimsFour(*input1, input1_port)
                                  : IsPortDimsN(*input1, input1_port, m);
      return input0_is_n && input1_is_m;
    }
    return false;
  }

  // Adds an int32 Const holding the shape [1, num_channels, 1, 1], used as
  // the shape operand of the Reshape inserted below.
  NodeDef* AddNodeShapeConst(const string& name, int num_channels,
                             const string& depended_node) {
    NodeDef* node = graph_->add_node();
    node_map_->AddNode(name, node);
    node->set_name(name);
    node->set_op("Const");
    node->set_device(node_->device());
    AttrValue attr_data_type;
    attr_data_type.set_type(DT_INT32);
    node->mutable_attr()->insert({"dtype", attr_data_type});

    AttrValue attr_tensor;
    Tensor tensor(DT_INT32, TensorShape({4}));
    std::vector<int> shape = {1, num_channels, 1, 1};
    for (int i = 0; i < static_cast<int>(shape.size()); i++) {
      tensor.flat<int>()(i) = shape[i];
    }
    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
    node->mutable_attr()->insert({"value", attr_tensor});
    if (is_in_frame_) {
      // This is to ensure the transpose node and the const node are in the
      // same frame.
      *node->add_input() = AsControlDependency(depended_node);
    }
    return node;
  }

  // Adds a Reshape node reading `input_name` with the shape supplied by
  // `shape_const_node_name`, placed on this node's device.
  NodeDef* AddNodeReshape(const string& node_name, const string& input_name,
                          const string& shape_const_node_name,
                          DataType data_type) {
    NodeDef* node = graph_->add_node();
    node_map_->AddNode(node_name, node);
    node->set_name(node_name);
    *node->add_input() = input_name;
    *node->add_input() = shape_const_node_name;
    node->set_op("Reshape");
    node->set_device(node_->device());

    AttrValue attr_type_indices;
    attr_type_indices.set_type(DT_INT32);
    node->mutable_attr()->insert({"Tshape", attr_type_indices});

    AttrValue attr_type_params;
    attr_type_params.set_type(data_type);
    node->mutable_attr()->insert({"T", attr_type_params});
    return node;
  }

  // If one operand is a vector, insert a Reshape to [1, C, 1, 1] on that
  // operand so elementwise broadcasting matches the NCHW data layout.
  Status CustomizedProcessing() override {
    int vector_index = -1;
    if (IsNDOperateWithMD(4, 1)) {
      vector_index = 1;
    } else if (IsNDOperateWithMD(1, 4)) {
      vector_index = 0;
    }
    if (vector_index != -1) {
      string base_name = strings::StrCat(node_->name(), "-", vector_index);
      string reshape_node_name = LayoutOptimizerNode(
          strings::StrCat(base_name, "-", kReshapeNHWCToNCHW));
      string shape_const_node_name =
          LayoutOptimizerNode(strings::StrCat(base_name, "-", kReshapeConst));
      auto input_node = node_map_->GetNode(node_->input(vector_index));
      TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
      int port;
      ParseNodeName(node_->input(vector_index), &port);
      // The vector length (channel count) comes from the producer's
      // inferred output shape.
      int vector_size = input_node->attr()
                            .at("_output_shapes")
                            .list()
                            .shape(port)
                            .dim(0)
                            .size();
      AddNodeShapeConst(shape_const_node_name, vector_size,
                        NodeName(node_->input(vector_index)));
      TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
      AddNodeReshape(reshape_node_name, node_->input(vector_index),
                     shape_const_node_name, node_->attr().at("T").type());
      // Splice the reshape between the vector producer and this node, and
      // keep the node map's edges consistent.
      node_map_->AddOutput(shape_const_node_name, reshape_node_name);
      node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
                              node_->name(), reshape_node_name);
      node_map_->AddOutput(reshape_node_name, node_->name());
      *node_->mutable_input(vector_index) = reshape_node_name;
    }
    return Status::OK();
  }
};
1424
1425 class ConcatProcessor : public AgnosticNodeProcessor {
1426 public:
ConcatProcessor(const OptimizeContext & opt_cxt)1427 explicit ConcatProcessor(const OptimizeContext& opt_cxt)
1428 : AgnosticNodeProcessor(opt_cxt) {
1429 // For Concat, the concat axis is the first input; for ConcatV2,
1430 // the last input. Note that if with control inputs, the number of inputs
1431 // is larger than the integer attribute N.
1432 int n = node_->attr().at("N").i();
1433 axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
1434 }
1435
1436 protected:
GetInputPos() const1437 std::vector<int> GetInputPos() const override {
1438 return DataInputPosConcat(*node_);
1439 }
1440
CustomizedProcessing()1441 Status CustomizedProcessing() override {
1442 DataType dtype =
1443 (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
1444 return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
1445 dtype);
1446 }
1447
1448 int axis_node_pos_;
1449 };
1450
1451 class FillProcessor : public AgnosticNodeProcessor {
1452 public:
FillProcessor(const OptimizeContext & opt_cxt)1453 explicit FillProcessor(const OptimizeContext& opt_cxt)
1454 : AgnosticNodeProcessor(opt_cxt) {}
1455
1456 protected:
GetInputPos() const1457 std::vector<int> GetInputPos() const override { return {}; }
1458
CustomizedProcessing()1459 Status CustomizedProcessing() override {
1460 DataType dtype = node_->attr().at("index_type").type();
1461 return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
1462 }
1463 };
1464
1465 class HistogramSummaryProcessor : public AgnosticNodeProcessor {
1466 public:
HistogramSummaryProcessor(const OptimizeContext & opt_cxt)1467 explicit HistogramSummaryProcessor(const OptimizeContext& opt_cxt)
1468 : AgnosticNodeProcessor(opt_cxt) {}
1469
1470 protected:
ShouldProcess() const1471 bool ShouldProcess() const override {
1472 auto input1 = node_map_->GetNode(node_->input(1));
1473 int port;
1474 ParseNodeName(node_->input(1), &port);
1475 return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
1476 IsPortDimsFour(*input1, port) && IsOnGPU();
1477 }
1478
GetInputPos() const1479 std::vector<int> GetInputPos() const override { return {1}; }
1480
AddLayoutTransposeToOutputs()1481 Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1482 };
1483
1484 class IdentityNProcessor : public AgnosticNodeProcessor {
1485 public:
IdentityNProcessor(const OptimizeContext & opt_cxt)1486 explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
1487 : AgnosticNodeProcessor(opt_cxt) {
1488 std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
1489 for (int i = 0; i < node_->input_size(); i++) {
1490 auto input = node_map_->GetNode(node_->input(i));
1491 int port;
1492 ParseNodeName(node_->input(i), &port);
1493 // Skip control input.
1494 if (port != -1) {
1495 bool is_agnostic =
1496 ops_format_agnostic.find(input->op()) != ops_format_agnostic.end();
1497 if (IsPortDimsFour(*input, port) &&
1498 ((IsNodeAfterNCHWToNHWC(*input) && is_agnostic) ||
1499 IsTransposeNCHWToNHWC(input->name()))) {
1500 input_pos_.push_back(i);
1501 }
1502 }
1503 }
1504 }
1505
1506 protected:
ShouldProcess() const1507 bool ShouldProcess() const override {
1508 return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
1509 IsOnGPU();
1510 }
1511
GetInputPos() const1512 std::vector<int> GetInputPos() const override { return input_pos_; }
1513
GetOutputPos() const1514 std::set<int> GetOutputPos() const override {
1515 std::set<int> output_pos{};
1516 for (const auto& input_pos : input_pos_) {
1517 output_pos.insert(input_pos);
1518 }
1519 return output_pos;
1520 }
1521
1522 private:
1523 std::vector<int> input_pos_;
1524 };
1525
1526 class ShapeProcessor : public IdentityNProcessor {
1527 public:
ShapeProcessor(const OptimizeContext & opt_cxt)1528 explicit ShapeProcessor(const OptimizeContext& opt_cxt)
1529 : IdentityNProcessor(opt_cxt) {}
1530
1531 protected:
AddLayoutTransposeToOutputs()1532 Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1533
CustomizedProcessing()1534 Status CustomizedProcessing() override {
1535 return AddTransformToOutputs("DataFormatVecPermute");
1536 }
1537 };
1538
1539 class MergeProcessor : public AgnosticNodeProcessor {
1540 public:
MergeProcessor(const OptimizeContext & opt_cxt)1541 explicit MergeProcessor(const OptimizeContext& opt_cxt)
1542 : AgnosticNodeProcessor(opt_cxt) {}
1543
1544 protected:
ShouldProcess() const1545 bool ShouldProcess() const override {
1546 return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
1547 IsEveryInputAfterNCHWToNHWC() && IsOnGPU();
1548 }
1549
GetInputPos() const1550 std::vector<int> GetInputPos() const override {
1551 std::vector<int> input_pos;
1552 int n = node_->attr().at("N").i();
1553 input_pos.reserve(n);
1554 for (int i = 0; i < n; i++) {
1555 input_pos.push_back(i);
1556 }
1557 return input_pos;
1558 }
1559
1560 private:
IsEveryInputAfterNCHWToNHWC() const1561 bool IsEveryInputAfterNCHWToNHWC() const {
1562 std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
1563 for (const auto& input : node_->input()) {
1564 auto input_node = node_map_->GetNode(input);
1565 int port;
1566 ParseNodeName(input, &port);
1567 bool is_agnostic = ops_format_agnostic.find(input_node->op()) !=
1568 ops_format_agnostic.end();
1569 if (IsPortDimsFour(*input_node, port) &&
1570 ((IsNodeAfterNCHWToNHWC(*input_node) && is_agnostic) ||
1571 IsTransposeNCHWToNHWC(input_node->name()))) {
1572 continue;
1573 }
1574 return false;
1575 }
1576 return true;
1577 }
1578 };
1579
1580 class PadProcessor : public AgnosticNodeProcessor {
1581 public:
PadProcessor(const OptimizeContext & opt_cxt)1582 explicit PadProcessor(const OptimizeContext& opt_cxt)
1583 : AgnosticNodeProcessor(opt_cxt) {}
1584
1585 protected:
CustomizedProcessing()1586 Status CustomizedProcessing() override {
1587 DataType dtype = node_->attr().at("Tpaddings").type();
1588 return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
1589 }
1590 };
1591
1592 class ReverseProcessor : public AgnosticNodeProcessor {
1593 public:
ReverseProcessor(const OptimizeContext & opt_cxt)1594 explicit ReverseProcessor(const OptimizeContext& opt_cxt)
1595 : AgnosticNodeProcessor(opt_cxt) {}
1596
1597 protected:
CustomizedProcessing()1598 Status CustomizedProcessing() override {
1599 DataType dtype = node_->attr().at("Tidx").type();
1600 return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
1601 }
1602 };
1603
1604 class SplitProcessor : public AgnosticNodeProcessor {
1605 public:
SplitProcessor(const OptimizeContext & opt_cxt)1606 explicit SplitProcessor(const OptimizeContext& opt_cxt)
1607 : AgnosticNodeProcessor(opt_cxt) {
1608 axis_node_pos_ = 0;
1609 }
1610
1611 protected:
GetInputPos() const1612 std::vector<int> GetInputPos() const override { return {1}; }
1613
GetOutputPos() const1614 std::set<int> GetOutputPos() const override {
1615 std::set<int> output_pos{0};
1616 if (HasAttribute(*node_, "num_split").ok()) {
1617 for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
1618 output_pos.insert(i);
1619 }
1620 }
1621 return output_pos;
1622 }
1623
CustomizedProcessing()1624 Status CustomizedProcessing() override {
1625 return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
1626 DT_INT32);
1627 }
1628
1629 int axis_node_pos_;
1630 };
1631
1632 class SplitVProcessor : public SplitProcessor {
1633 public:
SplitVProcessor(const OptimizeContext & opt_cxt)1634 explicit SplitVProcessor(const OptimizeContext& opt_cxt)
1635 : SplitProcessor(opt_cxt) {
1636 axis_node_pos_ = 2;
1637 }
1638
1639 protected:
GetInputPos() const1640 std::vector<int> GetInputPos() const override { return {0}; }
1641 };
1642
1643 class TernaryOpProcessor : public AgnosticNodeProcessor {
1644 public:
TernaryOpProcessor(const OptimizeContext & opt_cxt)1645 explicit TernaryOpProcessor(const OptimizeContext& opt_cxt)
1646 : AgnosticNodeProcessor(opt_cxt) {}
1647
1648 protected:
GetInputPos() const1649 std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
1650 };
1651
1652 class SelectProcessor : public AgnosticNodeProcessor {
1653 public:
SelectProcessor(const OptimizeContext & opt_cxt)1654 explicit SelectProcessor(const OptimizeContext& opt_cxt)
1655 : AgnosticNodeProcessor(opt_cxt) {}
1656
1657 protected:
ShouldProcess() const1658 bool ShouldProcess() const override {
1659 auto input0 = node_map_->GetNode(node_->input(0));
1660 int input0_port;
1661 ParseNodeName(node_->input(0), &input0_port);
1662 bool is_input0_scalar_vector_4d = IsPortDimsN(*input0, input0_port, 0) ||
1663 IsPortDimsN(*input0, input0_port, 1) ||
1664 IsPortDimsN(*input0, input0_port, 4);
1665 return AgnosticNodeProcessor::ShouldProcess() && is_input0_scalar_vector_4d;
1666 }
1667
GetInputPos() const1668 std::vector<int> GetInputPos() const override {
1669 auto input0 = node_map_->GetNode(node_->input(0));
1670 int input0_port;
1671 ParseNodeName(node_->input(0), &input0_port);
1672 // Input 0 could be a scalar, a vector with size matching the first
1673 // dimension of input 1 and 2, or must have the same shape as input 1 and 2.
1674 if (IsPortDimsFour(*input0, input0_port)) {
1675 return {0, 1, 2};
1676 } else {
1677 return {1, 2};
1678 }
1679 }
1680 };
1681
1682 class UnaryGradProcessor : public AgnosticNodeProcessor {
1683 public:
UnaryGradProcessor(const OptimizeContext & opt_cxt)1684 explicit UnaryGradProcessor(const OptimizeContext& opt_cxt)
1685 : AgnosticNodeProcessor(opt_cxt) {}
1686
1687 protected:
GetInputPos() const1688 std::vector<int> GetInputPos() const override { return {0, 1}; }
1689 };
1690
1691 class SliceProcessor : public AgnosticNodeProcessor {
1692 public:
SliceProcessor(const OptimizeContext & opt_cxt)1693 explicit SliceProcessor(const OptimizeContext& opt_cxt)
1694 : AgnosticNodeProcessor(opt_cxt) {
1695 // Skip the first input, which is the data to be sliced.
1696 start_ = 1;
1697 // Note that we can't use node_->input_size() here because there
1698 // could be control inputs.
1699 end_ = 2;
1700 }
1701
1702 protected:
ProcessInputs()1703 Status ProcessInputs() {
1704 for (int i = start_; i <= end_; i++) {
1705 DataType dtype = node_->attr().at("Index").type();
1706 TF_RETURN_IF_ERROR(
1707 UpdateOrTransformParamInput(i, "DataFormatVecPermute", dtype));
1708 }
1709 return Status::OK();
1710 }
1711
CustomizedProcessing()1712 Status CustomizedProcessing() override { return ProcessInputs(); }
1713
1714 int start_;
1715 int end_;
1716 };
1717
// Handles StridedSlice: permutes the begin/end/strides vectors (inputs 1..3)
// and rewrites the begin_mask/end_mask bit patterns to NCHW ordering.
class StridedSliceProcessor : public SliceProcessor {
 public:
  explicit StridedSliceProcessor(const OptimizeContext& opt_cxt)
      : SliceProcessor(opt_cxt) {
    start_ = 1;
    end_ = 3;
  }

 protected:
  // Only masks expressible as plain begin/end masks can be remapped.
  bool ShouldProcess() const override {
    return AgnosticNodeProcessor::ShouldProcess() && IsOnlyBeginEndMask();
  }

  Status CustomizedProcessing() override {
    TF_RETURN_IF_ERROR(UpdateMask("begin_mask"));
    TF_RETURN_IF_ERROR(UpdateMask("end_mask"));
    TF_RETURN_IF_ERROR(ProcessInputs());
    return Status::OK();
  }

 private:
  bool IsMaskZero(const string& mask) const {
    return node_->attr().at(mask).i() == 0;
  }

  bool IsOnlyBeginEndMask() const {
    return IsMaskZero("ellipsis_mask") && IsMaskZero("new_axis_mask") &&
           IsMaskZero("shrink_axis_mask");
  }

  // Remaps a 4-bit dimension mask from NHWC to NCHW ordering. Bit k of the
  // mask refers to dimension k, so the NHWC->NCHW dimension permutation
  // (H:1->2, W:2->3, C:3->1, N:0 fixed) moves bit 1 to bit 2, bit 2 to
  // bit 3, and bit 3 to bit 1. The switch below is that bit permutation
  // written out per value; 0, 1, 14, and 15 are its fixed points (no
  // spatial/channel bits set, or all three set).
  Status UpdateMask(const string& mask) {
    int i = node_->attr().at(mask).i();
    if (i < 0 || i > 15) {
      return errors::InvalidArgument("invalid mask value: ", i);
    }
    if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK();
    switch (i) {
      case 2:
      case 3:
        i += 2;
        break;
      case 4:
      case 5:
        i += 4;
        break;
      case 6:
      case 7:
        i += 6;
        break;
      case 8:
      case 9:
        i -= 6;
        break;
      case 10:
      case 11:
        i -= 4;
        break;
      case 12:
      case 13:
        i -= 2;
        break;
    }
    node_->mutable_attr()->at(mask).set_i(i);
    return Status::OK();
  }
};
1784
1785 class StridedSliceGradProcessor : public StridedSliceProcessor {
1786 public:
StridedSliceGradProcessor(const OptimizeContext & opt_cxt)1787 explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt)
1788 : StridedSliceProcessor(opt_cxt) {
1789 start_ = 0;
1790 end_ = 3;
1791 }
1792
1793 protected:
GetInputPos() const1794 std::vector<int> GetInputPos() const override { return {4}; }
1795 };
1796
// Handles Squeeze nodes that drop the spatial (H, W) or batch+spatial
// (N, H, W) dimensions of a 4-D NHWC input; the squeeze_dims attribute is
// rewritten to the corresponding NCHW dimension numbers.
class SqueezeProcessor : public AgnosticNodeProcessor {
 public:
  explicit SqueezeProcessor(const OptimizeContext& opt_cxt)
      : AgnosticNodeProcessor(opt_cxt) {}

 protected:
  bool ShouldProcess() const override {
    // 2-D output => squeezing H and W; 1-D output => squeezing N, H and W.
    bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
                             (IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
    return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
           IsInputConvertible() && is_dims_supported && IsOnGPU();
  }

  // The squeezed output is 1-D or 2-D, so no output transpose is inserted.
  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }

  // Rewrite squeeze_dims from NHWC ({1,2} or {0,1,2}) to NCHW ({2,3} or
  // {0,2,3}) numbering. An i_size of 2 means HW; 3 means NHW (entry 0 stays
  // 0 for N).
  Status CustomizedProcessing() override {
    TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
    auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
    if (list->i_size() == 2) {
      list->set_i(0, 2);
      list->set_i(1, 3);
    } else if (list->i_size() == 3) {
      list->set_i(1, 2);
      list->set_i(2, 3);
    }
    return Status::OK();
  }

 private:
  // True when the input is 4-D with H == W == 1 (optionally N == 1 too),
  // i.e. the dimensions being squeezed really are of size 1.
  bool IsInputConvertible() const {
    int input_port;
    auto input = node_map_->GetNode(node_->input(0));
    ParseNodeName(node_->input(0), &input_port);
    if (input->attr().find("_output_shapes") != input->attr().end()) {
      auto shape = input->attr().at("_output_shapes").list().shape(input_port);
      if (shape.dim_size() != 4) {
        return false;
      }
      if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
        return true;
      }
      if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
          shape.dim(2).size() == 1) {
        return true;
      }
    }
    return false;
  }

  // True when squeeze_dims exactly equals `axis` (or is empty, which
  // squeezes all size-1 dimensions).
  bool IsAlongAxis(const std::vector<int>& axis) const {
    if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
      auto list = node_->attr().at("squeeze_dims").list();
      // If list is empty, Squeeze op will squeeze all dimensions of size 1.
      if (list.i_size() == 0) return true;
      if (list.i_size() == axis.size()) {
        bool along_axis = true;
        for (int i = 0; i < axis.size(); i++) {
          along_axis = along_axis && (list.i(i) == axis[i]);
        }
        if (along_axis) return true;
      }
    }
    return false;
  }
  bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
  bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
};
1864
1865 class ReduceProcessor : public AgnosticNodeProcessor {
1866 public:
ReduceProcessor(const OptimizeContext & opt_cxt)1867 explicit ReduceProcessor(const OptimizeContext& opt_cxt)
1868 : AgnosticNodeProcessor(opt_cxt) {}
1869
1870 protected:
ShouldProcess() const1871 bool ShouldProcess() const override {
1872 auto input0 = node_map_->GetNode(node_->input(0));
1873 int port;
1874 ParseNodeName(node_->input(0), &port);
1875 return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
1876 IsPortDimsFour(*input0, port) && IsReduceAxisSupported() &&
1877 IsOnGPU();
1878 }
1879
CustomizedProcessing()1880 Status CustomizedProcessing() override {
1881 if (IsReduceAxisSupported()) {
1882 DataType dtype = node_->attr().at("Tidx").type();
1883 TF_RETURN_IF_ERROR(
1884 UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
1885 }
1886 return Status::OK();
1887 }
1888
AddLayoutTransposeToOutputs()1889 Status AddLayoutTransposeToOutputs() override {
1890 if (KeepDims()) {
1891 return AddTransformToOutputs("Transpose");
1892 }
1893 return Status::OK();
1894 }
1895
1896 private:
IsReduceAxisSupported() const1897 bool IsReduceAxisSupported() const {
1898 return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() ||
1899 IsAlongNHW() || IsAlongHW() || IsAlongC()) &&
1900 !KeepDims());
1901 }
1902
IsAlongAxis(const std::vector<int> & axis) const1903 bool IsAlongAxis(const std::vector<int>& axis) const {
1904 auto axis_node = node_map_->GetNode(node_->input(1));
1905 if (!IsConstant(*axis_node)) {
1906 return false;
1907 }
1908 if (HasAttribute(*axis_node, "value").ok()) {
1909 Tensor tensor;
1910 auto success = tensor.FromProto(axis_node->attr().at({"value"}).tensor());
1911 if (!success) {
1912 LOG(ERROR) << "Failed to parse TensorProto.";
1913 }
1914 if (tensor.dims() == 1 && tensor.dim_size(0) == axis.size()) {
1915 bool along_axis = true;
1916 for (int i = 0; i < axis.size(); i++) {
1917 along_axis = along_axis && (tensor.flat<int>()(i) == axis[i]);
1918 }
1919 if (along_axis) return true;
1920 }
1921 }
1922 return false;
1923 }
1924
IsAlongAllFourDims() const1925 bool IsAlongAllFourDims() const { return IsAlongAxis({0, 1, 2, 3}); }
1926
IsAlongHWC() const1927 bool IsAlongHWC() const { return IsAlongAxis({1, 2, 3}); }
1928
IsAlongNHW() const1929 bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
1930
IsAlongHW() const1931 bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
1932
IsAlongC() const1933 bool IsAlongC() const { return IsAlongAxis({3}); }
1934
KeepDims() const1935 bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
1936 };
1937
1938 class SwitchProcessor : public AgnosticNodeProcessor {
1939 public:
SwitchProcessor(const OptimizeContext & opt_cxt)1940 explicit SwitchProcessor(const OptimizeContext& opt_cxt)
1941 : AgnosticNodeProcessor(opt_cxt) {}
1942
1943 protected:
GetOutputPos() const1944 std::set<int> GetOutputPos() const override { return {0, 1}; }
1945 };
1946
1947 class TileProcessor : public AgnosticNodeProcessor {
1948 public:
TileProcessor(const OptimizeContext & opt_cxt)1949 explicit TileProcessor(const OptimizeContext& opt_cxt)
1950 : AgnosticNodeProcessor(opt_cxt) {}
1951
1952 protected:
CustomizedProcessing()1953 Status CustomizedProcessing() override {
1954 DataType dtype = node_->attr().at("Tmultiples").type();
1955 return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
1956 }
1957 };
1958
1959 class DataLayoutOptimizer : GraphProcessor {
1960 public:
DataLayoutOptimizer(const GraphProperties & graph_properties,const VirtualPlacer & virtual_placer,const LayoutOptimizer::TuningConfig & config,const std::unordered_set<string> & nodes_to_preserve,GraphDef * graph,NodeMap * node_map)1961 explicit DataLayoutOptimizer(
1962 const GraphProperties& graph_properties,
1963 const VirtualPlacer& virtual_placer,
1964 const LayoutOptimizer::TuningConfig& config,
1965 const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
1966 NodeMap* node_map)
1967 : GraphProcessor(graph_properties, virtual_placer, nodes_to_preserve,
1968 graph, node_map),
1969 config_(config) {}
1970
Optimize()1971 Status Optimize() {
1972 VLOG(1) << "Number of nodes for original graph: " << graph_->node_size();
1973 TF_RETURN_IF_ERROR(Expand());
1974 VLOG(1) << "Number of nodes after Expand: " << graph_->node_size();
1975 TF_RETURN_IF_ERROR(Collapse());
1976 VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size();
1977 return Status::OK();
1978 }
1979
1980 private:
AddNodePermNHWCToNCHW()1981 NodeDef* AddNodePermNHWCToNCHW() {
1982 return AddNodePermConst(LayoutOptimizerNode(kPermNHWCToNCHW), "",
1983 {0, 3, 1, 2});
1984 }
1985
AddNodePermNCHWToNHWC()1986 NodeDef* AddNodePermNCHWToNHWC() {
1987 return AddNodePermConst(LayoutOptimizerNode(kPermNCHWToNHWC), "",
1988 {0, 2, 3, 1});
1989 }
1990
1991 // Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic.
Expand()1992 Status Expand() {
1993 int node_size_original = graph_->node_size();
1994
1995 FrameView frame_view;
1996 TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph_));
1997
1998 // This is the first pass where we expand the nodes which support NCHW.
1999 std::set<string> ops_format_supported = GetOpsFormatSupported();
2000 for (int i = 0; i < node_size_original; i++) {
2001 if (IsNodeByLayoutOptimizer(graph_->node(i).name())) {
2002 return Status(error::INVALID_ARGUMENT,
2003 "The graph is already optimized by layout optimizer.");
2004 }
2005 if (ops_format_supported.find(graph_->node(i).op()) !=
2006 ops_format_supported.end()) {
2007 auto node = graph_->mutable_node(i);
2008 bool is_in_frame = frame_view.IsInFrame(*node);
2009 OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
2010 virtual_placer_, nodes_to_preserve_,
2011 is_in_frame);
2012 std::unique_ptr<NodeProcessor> node_processor;
2013 if (IsAvgPoolGrad(*node)) {
2014 node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
2015 } else if (IsBiasAddGrad(*node)) {
2016 node_processor.reset(new BiasAddGradProcessor(opt_cxt));
2017 } else if (IsConv2D(*node)) {
2018 node_processor.reset(new Conv2DProcessor(opt_cxt, config_.no_gemm));
2019 } else if (IsConv2DBackpropFilter(*node)) {
2020 node_processor.reset(
2021 new Conv2DBackpropFilterProcessor(opt_cxt, config_.no_gemm));
2022 } else if (IsConv2DBackpropInput(*node)) {
2023 node_processor.reset(
2024 new Conv2DBackpropInputProcessor(opt_cxt, config_.no_gemm));
2025 } else if (IsDepthwiseConv2dNative(*node)) {
2026 node_processor.reset(new Conv2DProcessor(opt_cxt, true));
2027 } else if (IsDepthwiseConv2dNativeBackpropFilter(*node)) {
2028 node_processor.reset(
2029 new Conv2DBackpropFilterProcessor(opt_cxt, true));
2030 } else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
2031 node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
2032 } else if (IsFusedBatchNormGrad(*node)) {
2033 node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
2034 } else if (IsMaxPoolV2(*node)) {
2035 node_processor.reset(new MaxPoolV2Processor(opt_cxt));
2036 } else if (IsMaxPoolGradV1(*node) || IsMaxPoolGradGradV1(*node)) {
2037 node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
2038 } else if (IsMaxPoolGradV2(*node) || IsMaxPoolGradGradV2(*node)) {
2039 node_processor.reset(new MaxPoolGradV2Processor(opt_cxt));
2040 } else {
2041 node_processor.reset(new NodeProcessor(opt_cxt));
2042 }
2043 TF_RETURN_IF_ERROR(node_processor->ConvertNode());
2044 }
2045 }
2046
2047 // This is the second pass where we expand layout-agnostic nodes. This pass
2048 // only needs to be performed if at least one node in the previous pass is
2049 // expanded.
2050 if (graph_->node_size() > node_size_original) {
2051 // Create Const nodes holding the permutation used by added Transposes of
2052 // nodes not in a frame.
2053 AddNodePermNHWCToNCHW();
2054 AddNodePermNCHWToNHWC();
2055 std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
2056 for (int i = 0; i < graph_->node_size(); i++) {
2057 if (ops_format_agnostic.find(graph_->node(i).op()) !=
2058 ops_format_agnostic.end()) {
2059 auto node = graph_->mutable_node(i);
2060 bool is_in_frame = frame_view.IsInFrame(*node);
2061 OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
2062 virtual_placer_, nodes_to_preserve_,
2063 is_in_frame);
2064 std::unique_ptr<NodeProcessor> node_processor;
2065 if (IsAddN(*node)) {
2066 node_processor.reset(new AddNProcessor(opt_cxt));
2067 } else if (IsBetainc(*node)) {
2068 node_processor.reset(new TernaryOpProcessor(opt_cxt));
2069 } else if (IsBinaryOp(*node)) {
2070 node_processor.reset(new BinaryOpProcessor(opt_cxt));
2071 } else if (IsConcat(*node)) {
2072 node_processor.reset(new ConcatProcessor(opt_cxt));
2073 } else if (IsFill(*node)) {
2074 node_processor.reset(new FillProcessor(opt_cxt));
2075 } else if (IsHistogramSummary(*node)) {
2076 node_processor.reset(new HistogramSummaryProcessor(opt_cxt));
2077 } else if (IsIdentityN(*node)) {
2078 node_processor.reset(new IdentityNProcessor(opt_cxt));
2079 } else if (IsMerge(*node)) {
2080 node_processor.reset(new MergeProcessor(opt_cxt));
2081 } else if (IsPad(*node) || IsMirrorPad(*node) ||
2082 IsMirrorPadGrad(*node)) {
2083 node_processor.reset(new PadProcessor(opt_cxt));
2084 } else if (IsReduceOp(*node)) {
2085 node_processor.reset(new ReduceProcessor(opt_cxt));
2086 } else if (IsReverseV2(*node)) {
2087 node_processor.reset(new ReverseProcessor(opt_cxt));
2088 } else if (IsSelect(*node)) {
2089 node_processor.reset(new SelectProcessor(opt_cxt));
2090 } else if (IsSlice(*node)) {
2091 node_processor.reset(new SliceProcessor(opt_cxt));
2092 } else if (IsStridedSlice(*node)) {
2093 node_processor.reset(new StridedSliceProcessor(opt_cxt));
2094 } else if (IsShape(*node) || IsShapeN(*node)) {
2095 node_processor.reset(new ShapeProcessor(opt_cxt));
2096 } else if (IsSplit(*node)) {
2097 node_processor.reset(new SplitProcessor(opt_cxt));
2098 } else if (IsSplitV(*node)) {
2099 node_processor.reset(new SplitVProcessor(opt_cxt));
2100 } else if (IsSqueeze(*node)) {
2101 node_processor.reset(new SqueezeProcessor(opt_cxt));
2102 } else if (IsStridedSliceGrad(*node)) {
2103 node_processor.reset(new StridedSliceGradProcessor(opt_cxt));
2104 } else if (IsSwitch(*node)) {
2105 node_processor.reset(new SwitchProcessor(opt_cxt));
2106 } else if (IsTile(*node)) {
2107 node_processor.reset(new TileProcessor(opt_cxt));
2108 } else if (IsUnaryGrad(*node)) {
2109 node_processor.reset(new UnaryGradProcessor(opt_cxt));
2110 } else {
2111 node_processor.reset(new AgnosticNodeProcessor(opt_cxt));
2112 }
2113 TF_RETURN_IF_ERROR(node_processor->ConvertNode());
2114 }
2115 }
2116 }
2117 return Status::OK();
2118 }
2119
2120 // Remove all node pairs, where a NCHW-to-NHWC node is followed by
2121 // a NHWC-to-NCHW node.
Collapse()2122 Status Collapse() {
2123 std::unordered_set<string> nodes_removable;
2124 for (int i = 0; i < graph_->node_size(); i++) {
2125 auto node = graph_->mutable_node(i);
2126 node->mutable_attr()->erase("_output_shapes");
2127 if (IsTransposeNHWCToNCHW(node->name()) ||
2128 IsDimMapNHWCToNCHW(node->name()) ||
2129 IsVecPermuteNHWCToNCHW(node->name())) {
2130 bool transpose_pair = IsTransposeNHWCToNCHW(node->name()) &&
2131 IsTransposeNCHWToNHWC(node->input(0));
2132 bool dim_map_pair = IsDimMapNHWCToNCHW(node->name()) &&
2133 IsDimMapNCHWToNHWC(node->input(0));
2134 bool vec_permute_pair = IsVecPermuteNHWCToNCHW(node->name()) &&
2135 IsVecPermuteNCHWToNHWC(node->input(0));
2136 if (transpose_pair || dim_map_pair || vec_permute_pair) {
2137 const string& trans_first = node->input(0);
2138 const string& trans_second = node->name();
2139 auto outputs = node_map_->GetOutputs(trans_second);
2140 CHECK(outputs.size() == 1)
2141 << "There is always only a single output for a Transpose node, "
2142 << "due to the way it is added by NodeProcessor.";
2143 NodeDef* output = *outputs.begin();
2144 string input = node_map_->GetNode(trans_first)->input(0);
2145 for (int i = 0; i < output->input_size(); i++) {
2146 if (output->input(i).compare(trans_second) == 0) {
2147 *output->mutable_input(i) = input;
2148 break;
2149 }
2150 }
2151 nodes_removable.insert(trans_first);
2152 nodes_removable.insert(trans_second);
2153 }
2154 }
2155 }
2156 graph_->mutable_node()->erase(
2157 std::remove_if(
2158 graph_->mutable_node()->begin(), graph_->mutable_node()->end(),
2159 [nodes_removable](const NodeDef& node) {
2160 return nodes_removable.find(node.name()) != nodes_removable.end();
2161 }),
2162 graph_->mutable_node()->end());
2163 return Status::OK();
2164 }
2165
2166 const LayoutOptimizer::TuningConfig& config_;
2167 };
2168
GetNumGPUs(const Cluster & cluster)2169 int GetNumGPUs(const Cluster& cluster) {
2170 auto devices = cluster.GetDevices();
2171 int num_gpus = 0;
2172 for (const auto& device : devices) {
2173 if (device.second.type() == "GPU") {
2174 num_gpus++;
2175 }
2176 }
2177 return num_gpus;
2178 }
2179 } // namespace
2180
Tune(const GrapplerItem & item,const GraphProperties & graph_properties,const TuningConfig & config,GraphDef * output)2181 Status LayoutOptimizer::Tune(const GrapplerItem& item,
2182 const GraphProperties& graph_properties,
2183 const TuningConfig& config, GraphDef* output) {
2184 auto status = graph_properties.AnnotateOutputShapes(output);
2185 if (!status.ok()) {
2186 VLOG(1) << "Annotate shape return status: " << status.ToString();
2187 *output = item.graph;
2188 return status;
2189 }
2190 NodeMap node_map(output);
2191 DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
2192 config, nodes_to_preserve_, output,
2193 &node_map);
2194 status = layout_optimizer.Optimize();
2195 return status;
2196 }
2197
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * output)2198 Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
2199 GraphDef* output) {
2200 if (cluster == nullptr) {
2201 return errors::InvalidArgument("cluster == nullptr");
2202 }
2203
2204 if (GetNumGPUs(*cluster) < 1) {
2205 // LayoutOptimizer is currently only tuned for GPU.
2206 *output = item.graph;
2207 return Status::OK();
2208 }
2209
2210 virtual_placer_.reset(new VirtualPlacer(cluster));
2211 nodes_to_preserve_ = item.NodesToPreserve();
2212 GraphProperties graph_properties(item);
2213 auto status = graph_properties.InferStatically(false);
2214 if (!status.ok()) {
2215 VLOG(1) << "Infer shape return status: " << status.ToString();
2216 *output = item.graph;
2217 return status;
2218 }
2219 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
2220
2221 TuningConfig config;
2222 config.no_gemm = true;
2223 // TODO(yaozhang): Enable tuning with various TuningConfig choices with
2224 // the measurement-based estimator.
2225 status = Tune(item, graph_properties, config, output);
2226 if (!status.ok()) {
2227 *output = item.graph;
2228 }
2229 return status;
2230 }
2231
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimize_output,double result)2232 void LayoutOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
2233 const GraphDef& optimize_output, double result) {
2234 // Nothing to do for LayoutOptimizer.
2235 }
2236
2237 } // end namespace grappler
2238 } // end namespace tensorflow
2239