1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <deque>
17 #include <unordered_set>
18 
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/memory_types.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/devices.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
32 #include "tensorflow/core/grappler/utils/frame.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/util/device_name_utils.h"
37 
38 namespace tensorflow {
39 namespace grappler {
40 namespace {
41 
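// Name suffixes for the helper nodes (perm constants, Transpose,
// DataFormatDimMap/VecPermute, Reshape) that this optimizer inserts;
// IsNodeType() below recognizes such nodes by these suffixes.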
42 const char kSuffix[] = "LayoutOptimizer";
43 const char kPermNHWCToNCHW[] = "PermConstNHWCToNCHW";
44 const char kPermNCHWToNHWC[] = "PermConstNCHWToNHWC";
45 const char kTransposeNHWCToNCHW[] = "TransposeNHWCToNCHW";
46 const char kTransposeNCHWToNHWC[] = "TransposeNCHWToNHWC";
47 const char kDimMapNHWCToNCHW[] = "DimMapNHWCToNCHW";
48 const char kDimMapNCHWToNHWC[] = "DimMapNCHWToNHWC";
49 const char kVecPermuteNHWCToNCHW[] = "VecPermuteNHWCToNCHW";
50 const char kVecPermuteNCHWToNHWC[] = "VecPermuteNCHWToNHWC";
51 const char kReshapeNHWCToNCHW[] = "ReshapeNHWCToNCHW";
52 const char kReshapeConst[] = "ReshapeConst";
53 
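// Ops that carry a data_format attribute and can be switched to NCHW in place.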
54 std::set<string> GetOpsFormatSupported() {
55   std::set<string> ops_format_supported = {
56       "AvgPool",
57       "AvgPoolGrad",
58       "Conv2D",
59       "Conv2DBackpropFilter",
60       "Conv2DBackpropInput",
61       "BiasAdd",
62       "BiasAddGrad",
63       "DepthwiseConv2dNative",
64       "DepthwiseConv2dNativeBackpropInput",
65       "DepthwiseConv2dNativeBackpropFilter",
66       "FusedBatchNorm",
67       "FusedBatchNormV2",
68       "FusedBatchNormGrad",
69       "FusedBatchNormGradV2",
70       "FusedConv2DBiasActivation",
71       "MaxPool",
72       "MaxPoolV2",
73       "MaxPoolGrad",
74       "MaxPoolGradGrad",
75       "MaxPoolGradV2",
76       "MaxPoolGradGradV2",
77       "SpaceToDepth",
78       "DepthToSpace"};
79   return ops_format_supported;
80 }
81 
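// Ops that do not care about layout; they are only rewritten when they sit
// downstream of an NCHW-to-NHWC transform (see AgnosticNodeProcessor below).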
82 std::set<string> GetOpsFormatAgnostic() {
83   std::set<string> ops_format_agnostic = {"Abs",
84                                           "Add",
85                                           "AddN",
86                                           "AddV2",
87                                           "Acos",
88                                           "Acosh",
89                                           "All",
90                                           "Angle",
91                                           "Any",
92                                           "ApproximateEqual",
93                                           "Asin",
94                                           "Asinh",
95                                           "Atan",
96                                           "Atan2",
97                                           "Atanh",
98                                           "Betainc",
99                                           "Bitcast",
100                                           "Cast",
101                                           "Ceil",
102                                           "CheckNumerics",
103                                           "Complex",
104                                           "ComplexAbs",
105                                           "Concat",
106                                           "ConcatV2",
107                                           "Conj",
108                                           "Cos",
109                                           "Cosh",
110                                           "Digamma",
111                                           "Div",
112                                           "Elu",
113                                           "EluGrad",
114                                           "Enter",
115                                           "Equal",
116                                           "Erf",
117                                           "Erfc",
118                                           "Exit",
119                                           "Exp",
120                                           "Expm1",
121                                           "FakeQuantWithMinMaxVars",
122                                           "FakeQuantWithMinMaxArgs",
123                                           "Fill",
124                                           "Floor",
125                                           "FloorDiv",
126                                           "FloorMod",
127                                           "Greater",
128                                           "GreaterEqual",
129                                           "GuaranteeConst",
130                                           "HistogramSummary",
131                                           "Identity",
132                                           "IdentityN",
133                                           "Igamma",
134                                           "Igammac",
135                                           "Imag",
136                                           "Inv",
137                                           "InvGrad",
138                                           "IsFinite",
139                                           "IsInf",
140                                           "IsNan",
141                                           "Less",
142                                           "LessEqual",
143                                           "Lgamma",
144                                           "Log",
145                                           "LogicalAnd",
146                                           "LogicalNot",
147                                           "LogicalOr",
148                                           "Log1p",
149                                           "Max",
150                                           "Maximum",
151                                           "Mean",
152                                           "Merge",
153                                           "Min",
154                                           "Minimum",
155                                           "Mod",
156                                           "Mul",
157                                           "Neg",
158                                           "NextIteration",
159                                           "NotEqual",
160                                           "OnesLike",
161                                           "Pad",
162                                           "PreventGradient",
163                                           "Prod",
164                                           "Polygamma",
165                                           "QuantizeAndDequantizeV2",
166                                           "QuantizeAndDequantizeV3",
167                                           "Pow",
168                                           "Real",
169                                           "RealDiv",
170                                           "Reciprocal",
171                                           "ReciprocalGrad",
172                                           "Relu",
173                                           "Relu6",
174                                           "Relu6Grad",
175                                           "ReluGrad",
176                                           "Rint",
177                                           "Select",
178                                           "Selu",
179                                           "SeluGrad",
180                                           "Shape",
181                                           "ShapeN",
182                                           "Sigmoid",
183                                           "SigmoidGrad",
184                                           "Sign",
185                                           "Sin",
186                                           "Sinh",
187                                           "Slice",
188                                           "Snapshot",
189                                           "Softplus",
190                                           "SoftplusGrad",
191                                           "Split",
192                                           "SplitV",
193                                           "StridedSlice",
194                                           "StridedSliceGrad",
195                                           "Switch",
196                                           "Tile",
197                                           "TruncateDiv",
198                                           "TruncateMod",
199                                           "ReverseV2",
200                                           "Round",
201                                           "Rsqrt",
202                                           "RsqrtGrad",
203                                           "Sqrt",
204                                           "SqrtGrad",
205                                           "Square",
206                                           "SquaredDifference",
207                                           "Squeeze",
208                                           "StopGradient",
209                                           "Sub",
210                                           "Sum",
211                                           "Tan",
212                                           "Tanh",
213                                           "TanhGrad",
214                                           "ZerosLike",
215                                           "Zeta"};
216   return ops_format_agnostic;
217 }
218 
219 bool EndWith(const string& str, const string& ending) {
220   if (str.size() < ending.size()) return false;
221   if (str.substr(str.size() - ending.size(), ending.size()) == ending)
222     return true;
223   return false;
224 }
225 
226 bool IsNodeByLayoutOptimizer(const string& node_name) {
227   const string suffix = kSuffix;
228   return EndWith(node_name, suffix);
229 }
230 
231 bool IsNodeType(const string& node_name, const string& type) {
232   const string suffix = strings::StrCat(type, "-", kSuffix);
233   return EndWith(node_name, suffix);
234 }
235 
236 bool IsTransposeNHWCToNCHW(const string& node_name) {
237   return IsNodeType(node_name, kTransposeNHWCToNCHW);
238 }
239 
240 bool IsTransposeNCHWToNHWC(const string& node_name) {
241   return IsNodeType(node_name, kTransposeNCHWToNHWC);
242 }
243 
244 bool IsDimMapNHWCToNCHW(const string& node_name) {
245   return IsNodeType(node_name, kDimMapNHWCToNCHW);
246 }
247 
248 bool IsDimMapNCHWToNHWC(const string& node_name) {
249   return IsNodeType(node_name, kDimMapNCHWToNHWC);
250 }
251 
252 bool IsVecPermuteNHWCToNCHW(const string& node_name) {
253   return IsNodeType(node_name, kVecPermuteNHWCToNCHW);
254 }
255 
256 bool IsVecPermuteNCHWToNHWC(const string& node_name) {
257   return IsNodeType(node_name, kVecPermuteNCHWToNHWC);
258 }
259 
260 bool IsConcat(const NodeDef& node) {
261   const auto op = node.op();
262   return op == "Concat" || op == "ConcatV2";
263 }
264 
265 bool IsConcatV1(const NodeDef& node) {
266   const auto op = node.op();
267   return op == "Concat";
268 }
269 
270 bool IsMaxPoolV2(const NodeDef& node) {
271   const auto& op = node.op();
272   return op == "MaxPoolV2";
273 }
274 
275 bool IsMaxPoolGradV1(const NodeDef& node) {
276   const auto& op = node.op();
277   return op == "MaxPoolGrad";
278 }
279 
280 bool IsMaxPoolGradV2(const NodeDef& node) {
281   const auto& op = node.op();
282   return op == "MaxPoolGradV2";
283 }
284 
285 bool IsMaxPoolGradGradV1(const NodeDef& node) {
286   const auto& op = node.op();
287   return op == "MaxPoolGradGrad";
288 }
289 
290 bool IsMaxPoolGradGradV2(const NodeDef& node) {
291   const auto& op = node.op();
292   return op == "MaxPoolGradGradV2";
293 }
294 
295 bool IsUnaryGrad(const NodeDef& node) {
296   bool is_unary_grad =
297       IsEluGrad(node) || IsInvGrad(node) || IsReciprocalGrad(node) ||
298       IsRelu6Grad(node) || IsReluGrad(node) || IsRsqrtGrad(node) ||
299       IsSeluGrad(node) || IsSigmoidGrad(node) || IsSoftplusGrad(node) ||
300       IsSoftsignGrad(node) || IsSqrtGrad(node) || IsTanhGrad(node);
301   return is_unary_grad;
302 }
303 
304 bool IsComparisonOp(const NodeDef& node) {
305   bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
306                     IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
307                     IsLessEqual(node) || IsNotEqual(node);
308   return is_compare;
309 }
310 
311 bool IsReduceOp(const NodeDef& node) {
312   return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
313          IsMin(node) || IsAll(node) || IsAny(node);
314 }
315 
316 bool IsBinaryOp(const NodeDef& node) {
317   bool is_binary =
318       IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
319       IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
320       IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
321       IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
322       IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
323       IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
324   return is_binary;
325 }
326 
327 std::vector<int> NonControlInputs(const NodeDef& node) {
328   std::vector<int> pos;
329   for (int i = 0; i < node.input_size(); i++) {
330     if (!IsControlInput(node.input(i))) {
331       pos.push_back(i);
332     }
333   }
334   return pos;
335 }
336 
337 std::vector<int> DataInputPosConcat(const NodeDef& node) {
338   int n = node.attr().at("N").i();
339   std::vector<int> input_pos;
340   int start = (IsConcatV1(node)) ? 1 : 0;
341   int end = start + n;
342   for (int i = start; i < end; i++) {
343     input_pos.push_back(i);
344   }
345   return input_pos;
346 }
347 
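// Returns the input positions that carry data tensors, skipping parameter
// inputs such as split dims, axes, and shape vectors.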
348 std::vector<int> DataInputPos(const NodeDef& node) {
349   if (IsSplit(node) || IsHistogramSummary(node)) {
350     return {1};
351   }
352   if (IsStridedSliceGrad(node)) {
353     return {4};
354   }
355   if (IsBinaryOp(node) || IsUnaryGrad(node)) {
356     return {0, 1};
357   }
358   if (IsBetainc(node) || IsSelect(node)) {
359     return {0, 1, 2};
360   }
361   if (IsShapeN(node) || IsIdentityN(node) || IsAddN(node) || IsMerge(node)) {
362     return NonControlInputs(node);
363   }
364   if (IsConcat(node)) {
365     return DataInputPosConcat(node);
366   }
367   if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
368     return {0};
369   }
370   return {};
371 }
372 
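// Returns true if the given output port of the node is placed in host memory,
// or if no kernel is registered for the node's device.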
373 bool IsHostMemory(const NodeDef& node, int output_port) {
374   DeviceNameUtils::ParsedName parsed_name;
375   if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
376     DeviceType device_type(parsed_name.type);
377     Status s = FindKernelDef(device_type, node, nullptr, nullptr);
378     if (s.ok()) {
379       tensorflow::MemoryTypeVector in_mtypes;
380       tensorflow::MemoryTypeVector out_mtypes;
381       s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
382                                          node, &in_mtypes, &out_mtypes);
383       if (s.ok()) {
384         if (out_mtypes[output_port] == HOST_MEMORY) {
385           return true;
386         }
387       }
388     } else {
389       return true;
390     }
391   }
392   return false;
393 }
394 
395 class GraphProcessor {
396  public:
397   GraphProcessor(const GraphProperties& graph_properties,
398                  const VirtualPlacer& virtual_placer,
399                  const std::unordered_set<string>& nodes_to_preserve,
400                  GraphDef* graph, NodeMap* node_map)
401       : graph_properties_(graph_properties),
402         virtual_placer_(virtual_placer),
403         nodes_to_preserve_(nodes_to_preserve),
404         graph_(graph),
405         node_map_(node_map) {}
406 
407  protected:
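  // Adds an int32 Const of shape {4} holding a transpose permutation such as
  // {0, 3, 1, 2}.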
408   NodeDef* AddNodePermConst(const string& name, const string& device,
409                             const std::vector<int>& permutation) {
410     NodeDef* node = graph_->add_node();
411     node_map_->AddNode(name, node);
412     node->set_name(name);
413     node->set_op("Const");
414     AttrValue attr_data_type;
415     attr_data_type.set_type(DT_INT32);
416     node->mutable_attr()->insert({"dtype", attr_data_type});
417     AttrValue attr_tensor;
418     Tensor tensor(DT_INT32, TensorShape({4}));
419     for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
420       tensor.flat<int>()(i) = permutation[i];
421     }
422     tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
423     node->mutable_attr()->insert({"value", attr_tensor});
424     string device_name;
425     if (device.empty()) {
426       device_name = virtual_placer_.get_canonical_device_name(*node);
427     } else {
428       device_name = device;
429     }
430     node->set_device(device_name);
431     return node;
432   }
433 
434   NodeDef* AddNodeConstScalar(const string& name, const string& device,
435                               DataType dtype, int value) {
436     NodeDef* node = graph_->add_node();
437     node_map_->AddNode(name, node);
438     node->set_name(name);
439     node->set_op("Const");
440     AttrValue attr_data_type;
441     attr_data_type.set_type(dtype);
442     node->mutable_attr()->insert({"dtype", attr_data_type});
443     AttrValue attr_tensor;
444     Tensor tensor(dtype, TensorShape({}));
445     tensor.scalar<int>()() = value;
446     tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
447     node->mutable_attr()->insert({"value", attr_tensor});
448     string device_name;
449     if (device.empty()) {
450       device_name = virtual_placer_.get_canonical_device_name(*node);
451     } else {
452       device_name = device;
453     }
454     node->set_device(device_name);
455     return node;
456   }
457 
458   string LayoutOptimizerNode(const string& base_name) {
459     return strings::StrCat(base_name, "-", kSuffix);
460   }
461 
462   const GraphProperties& graph_properties_;
463   const VirtualPlacer& virtual_placer_;
464   const std::unordered_set<string>& nodes_to_preserve_;
465   GraphDef* graph_;
466   NodeMap* node_map_;
467 };
468 
469 struct OptimizeContext {
470   OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
471                   const GraphProperties& graph_properties,
472                   const VirtualPlacer& virtual_placer,
473                   const std::unordered_set<string>& nodes_to_preserve,
474                   bool is_in_frame)
475       : graph(graph),
476         node(node),
477         node_map(node_map),
478         graph_properties(graph_properties),
479         virtual_placer(virtual_placer),
480         nodes_to_preserve(nodes_to_preserve),
481         is_in_frame(is_in_frame) {}
482   GraphDef* graph;
483   NodeDef* node;
484   NodeMap* node_map;
485   const GraphProperties& graph_properties;
486   const VirtualPlacer& virtual_placer;
487   const std::unordered_set<string>& nodes_to_preserve;
488   bool is_in_frame;
489 };
490 
491 class NodeProcessor : public GraphProcessor {
492  public:
493   explicit NodeProcessor(const OptimizeContext& opt_cxt)
494       : GraphProcessor(opt_cxt.graph_properties, opt_cxt.virtual_placer,
495                        opt_cxt.nodes_to_preserve, opt_cxt.graph,
496                        opt_cxt.node_map),
497         node_(opt_cxt.node),
498         is_in_frame_(opt_cxt.is_in_frame) {}
499   virtual ~NodeProcessor() {}
500   virtual Status ConvertNode() {
501     if (ShouldProcess()) {
502       UpdateAttrDataFormat();
503       UpdateAttrKSize();
504       UpdateAttrStrides();
505       UpdateAttrDilations();
506       UpdateAttrExplicitPaddings();
507       UpdateAttrShape();
508       TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
509       TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
510       TF_RETURN_IF_ERROR(CustomizedProcessing());
511     }
512     return Status::OK();
513   }
514 
515  protected:
516   bool IsPortDimsN(const NodeDef& node, int port, int n) const {
517     if (node.attr().find("_output_shapes") != node.attr().end()) {
518       if (node.attr().at("_output_shapes").list().shape_size() > port) {
519         auto shape = node.attr().at("_output_shapes").list().shape(port);
520         if (shape.unknown_rank()) {
521           return false;
522         }
523         if (shape.dim_size() == n) {
524           return true;
525         }
526       }
527     }
528     return false;
529   }
530 
531   bool IsPortZeroDimsN(const NodeDef& node, int n) const {
532     return IsPortDimsN(node, 0, n);
533   }
534 
535   bool IsPortZeroDimsFour(const NodeDef& node) const {
536     return NodeProcessor::IsPortZeroDimsN(node, 4) ||
537            IsTransposeNCHWToNHWC(node.name());
538   }
539 
540   bool IsPortDimsFour(const NodeDef& node, int port) const {
541     return NodeProcessor::IsPortDimsN(node, port, 4) ||
542            IsTransposeNCHWToNHWC(node.name());
543   }
544 
545   bool IsNHWC() const {
546     if (node_->attr().find("data_format") != node_->attr().end()) {
547       if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
548         return true;
549       }
550     }
551     return false;
552   }
553 
554   bool HasOutputs() const {
555     auto outputs = node_map_->GetOutputs(node_->name());
556     return !outputs.empty();
557   }
558 
559   Status HasAttribute(const NodeDef& node, const string& attr) const {
560     if (node.attr().find(attr) == node.attr().end()) {
561       return Status(error::INVALID_ARGUMENT,
562                     strings::StrCat("Missing attribute ", attr));
563     }
564     return Status::OK();
565   }
566 
567   bool MustPreserve() const {
568     return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
569   }
570 
571   bool IsOnGPU() const {
572     string device_name;
573     if (node_->device().empty()) {
574       device_name = virtual_placer_.get_canonical_device_name(*node_);
575     } else {
576       device_name = node_->device();
577     }
578     string device;
579     string not_used;
580     if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
581         str_util::StrContains(str_util::Lowercase(device),
582                               str_util::Lowercase(DEVICE_GPU))) {
583       return true;
584     }
585     return false;
586   }
587 
588   virtual bool ShouldProcess() const {
589     return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
590            HasOutputs() && IsOnGPU();
591   }
592 
593   virtual void UpdateAttrShape() {
594     if (node_->attr().find("_output_shapes") != node_->attr().end()) {
595       for (const auto& pos : GetOutputPos()) {
596         auto shape = node_->mutable_attr()
597                          ->at("_output_shapes")
598                          .mutable_list()
599                          ->mutable_shape(pos);
600         if (shape->dim_size() == 4) {
601           int64 h = shape->dim(1).size();
602           int64 w = shape->dim(2).size();
603           int64 c = shape->dim(3).size();
604           shape->mutable_dim(1)->set_size(c);
605           shape->mutable_dim(2)->set_size(h);
606           shape->mutable_dim(3)->set_size(w);
607         }
608       }
609     }
610   }
611 
612   Status UpdateAttrValueOfInput(int input_index, bool permute) {
613     auto input_node = node_map_->GetNode(node_->input(input_index));
614     // We created a copy of the node, so that we don't modify the original node,
615     // which might be used elsewhere. Note that this copy also copies the
616     // control dependency input in the case this node is inside a loop,
617     // to ensure added_node is in the same frame with node_.
618     NodeDef* added_node = graph_->add_node();
619     *added_node = *input_node;
620     string base_name = strings::StrCat(node_->name(), "-", input_index);
621     string node_name = LayoutOptimizerNode(base_name);
622     added_node->set_name(node_name);
623     *node_->mutable_input(input_index) = node_name;
624     node_map_->AddNode(node_name, added_node);
625     node_map_->AddOutput(node_name, node_->name());
626     return UpdateAttrValue(added_node, permute);
627   }
628 
629   virtual std::vector<int> GetInputPos() const { return {0}; }
630 
631   virtual std::set<int> GetOutputPos() const {
632     // For most nodes, no need to process control nodes or nodes that use an
633     // output other than the first output: only the first output is of
634     // 4D NCHW/NHWC format and thus relevant here.
635     std::set<int> output_pos = {0};
636     return output_pos;
637   }
638 
639   virtual Status AddLayoutTransposeToInputs() {
640     std::vector<int> input_pos = GetInputPos();
641     for (const auto& pos : input_pos) {
642       string node_name = LayoutOptimizerNode(
643           strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
644       DataType dtype =
645           graph_properties_.GetInputProperties(node_->name())[pos].dtype();
646       auto input_node = node_map_->GetNode(node_->input(pos));
647       TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
648       string const_name = GetOrAddNodePermNHWCToNCHW(pos);
649       int output_pos;
650       ParseNodeName(node_->input(pos), &output_pos);
651       AddNodeTranspose(
652           node_name, node_->input(pos), const_name, dtype,
653           input_node->attr().at("_output_shapes").list().shape(output_pos),
654           true);
655       node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
656                               node_name);
657       node_map_->AddOutput(node_name, node_->name());
658       *node_->mutable_input(pos) = node_name;
659     }
660     return Status::OK();
661   }
662 
663   Status AddTransformToOutputs(const string& op) {
664     auto outputs = node_map_->GetOutputs(node_->name());
665     string const_name = GetOrAddNodePermNCHWToNHWC();
666     int output_count = 0;
667     for (const auto& output : outputs) {
668       int connections = 0;
669       int connections_removed = 0;
670       for (int i = 0; i < output->input_size(); i++) {
671         auto& input = *output->mutable_input(i);
672         int input_port;
673         string input_name = ParseNodeName(input, &input_port);
674         auto output_pos = GetOutputPos();
675         if (input_name == node_->name()) {
676           connections++;
677           if (output_pos.find(input_port) != output_pos.end()) {
678             connections_removed++;
679             string added_node_base_name =
680                 strings::StrCat(node_->name(), "-", output_count, "-", i);
681             string added_node_name;
682             DataType dtype =
683                 graph_properties_.GetOutputProperties(node_->name())[input_port]
684                     .dtype();
685             if (op == "Transpose") {
686               added_node_name = LayoutOptimizerNode(strings::StrCat(
687                   added_node_base_name, "-", kTransposeNCHWToNHWC));
688               TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
689               AddNodeTranspose(
690                   added_node_name, input, const_name, dtype,
691                   node_->attr().at("_output_shapes").list().shape(input_port),
692                   false);
693             } else if (op == "DataFormatVecPermute") {
694               added_node_name = LayoutOptimizerNode(strings::StrCat(
695                   added_node_base_name, "-", kVecPermuteNCHWToNHWC));
696               AddNodeDataFormatOp(added_node_name, input, op, dtype, false);
697             } else {
698               return errors::InvalidArgument("Unsupported op type: ", op);
699             }
700             input = added_node_name;
701             node_map_->AddOutput(node_->name(), added_node_name);
702             node_map_->AddOutput(added_node_name, output->name());
703           }
704         }
705       }
706       if (connections == connections_removed) {
707         node_map_->RemoveOutput(node_->name(), output->name());
708       }
709       output_count++;
710     }
711     return Status::OK();
712   }
713 
714   virtual Status AddLayoutTransposeToOutputs() {
715     return AddTransformToOutputs("Transpose");
716   }
717 
718   virtual Status CustomizedProcessing() { return Status::OK(); }
719 
720   Status UpdateOrTransformParamInput(int param_index, const string& op,
721                                      DataType dtype) {
722     auto param_node = node_map_->GetNode(node_->input(param_index));
723     bool permute = (op == "DataFormatVecPermute") ? true : false;
724     if (IsConstant(*param_node)) {
725       TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(param_index, permute));
726     } else {
727       AddDataFormatTranformToParamInput(op, param_index, dtype);
728     }
729     return Status::OK();
730   }
731 
732   NodeDef* node_;
733   bool is_in_frame_;
734 
735  private:
736   void UpdateAttrKSize() {
737     if (node_->attr().find("ksize") != node_->attr().end()) {
738       auto list = node_->mutable_attr()->at("ksize").mutable_list();
739       UpdateTuple(list);
740     }
741   }
742 
743   void UpdateAttrStrides() {
744     if (node_->attr().find("strides") != node_->attr().end()) {
745       auto list = node_->mutable_attr()->at("strides").mutable_list();
746       UpdateTuple(list);
747     }
748   }
749 
750   void UpdateAttrDilations() {
751     if (node_->attr().find("dilations") != node_->attr().end()) {
752       auto list = node_->mutable_attr()->at("dilations").mutable_list();
753       UpdateTuple(list);
754     }
755   }
756 
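  // Reorders the 8-element explicit_paddings list from NHWC pair order
  // (N, H, W, C) to NCHW pair order; the channel padding pair is set to zero.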
757   void UpdateAttrExplicitPaddings() {
758     if (node_->attr().find("explicit_paddings") != node_->attr().end()) {
759       auto list = node_->mutable_attr()->at("explicit_paddings").mutable_list();
760       int size = list->i_size();
761       if (size == 8) {
762         int64 height_before = list->i(2);
763         int64 height_after = list->i(3);
764         int64 width_before = list->i(4);
765         int64 width_after = list->i(5);
766         list->set_i(2, 0);
767         list->set_i(3, 0);
768         list->set_i(4, height_before);
769         list->set_i(5, height_after);
770         list->set_i(6, width_before);
771         list->set_i(7, width_after);
772       } else if (size != 0) {
773         LOG(ERROR) << "Cannot handle explicit_paddings attribute of size "
774                    << size;
775       }
776     }
777   }
778 
779   void UpdateAttrDataFormat() {
780     if (node_->attr().find("data_format") != node_->attr().end()) {
781       if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
782         string* data_format =
783             node_->mutable_attr()->at("data_format").mutable_s();
784         *data_format = "NCHW";
785       }
786     }
787   }
788 
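  // Rewrites a Const input for NCHW. With permute=true, a length-4 vector (or
  // 4x2 matrix, e.g. paddings) is reordered from NHWC to NCHW; otherwise the
  // values are treated as dimension indices and remapped (1->2, 2->3, 3->1).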
789   Status UpdateAttrValue(NodeDef* node, bool permute) {
790     TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
791     Tensor tensor;
792     auto success =
793         tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
794     if (!success) {
795       LOG(ERROR) << "Failed to parse TensorProto.";
796     }
797 
798     if (permute) {
799       if (tensor.dims() == 1) {
800         if (tensor.flat<int>().size() == 4) {
801           int c = tensor.flat<int>()(3);
802           tensor.flat<int>()(3) = tensor.flat<int>()(2);
803           tensor.flat<int>()(2) = tensor.flat<int>()(1);
804           tensor.flat<int>()(1) = c;
805         } else {
806           return Status(error::INVALID_ARGUMENT,
807                         strings::StrCat("Unsupported tensor size: ",
808                                         tensor.flat<int>().size()));
809         }
810       } else if (tensor.dims() == 2) {
811         for (int i = 0; i < 2; i++) {
812           int c = tensor.matrix<int>()(3, i);
813           tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
814           tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
815           tensor.matrix<int>()(1, i) = c;
816         }
817       } else {
818         return Status(
819             error::INVALID_ARGUMENT,
820             strings::StrCat("Unsupported dimension size: ", tensor.dims()));
821       }
822     } else {
823       for (int i = 0; i < tensor.flat<int>().size(); i++) {
824         int value = tensor.flat<int>()(i);
825         value = (value >= 0) ? value : value + 4;
826         if (value == 1 || value == 2) {
827           value = value + 1;
828         } else if (value == 3) {
829           value = 1;
830         }
831         tensor.flat<int>()(i) = value;
832       }
833     }
834 
835     if (tensor.dims() == 0) {
836       tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
837     } else {
838       tensor.AsProtoTensorContent(
839           node->mutable_attr()->at({"value"}).mutable_tensor());
840     }
841     return Status::OK();
842   }
843 
844   NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
845                             const string& const_name, DataType data_type,
846                             const TensorShapeProto& input_shape,
847                             bool NHWCToNCHW) {
848     NodeDef* node = graph_->add_node();
849     node_map_->AddNode(node_name, node);
850     node->set_name(node_name);
851     *node->add_input() = input_name;
852     *node->add_input() = const_name;
853     node->set_op("Transpose");
854     node->set_device(node_->device());
855     AttrValue attr_data_type;
856     attr_data_type.set_type(data_type);
857     node->mutable_attr()->insert({"T", attr_data_type});
858     AttrValue attr_data_type_perm;
859     attr_data_type_perm.set_type(DT_INT32);
860     node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
861     if (!input_shape.unknown_rank()) {
862       AttrValue attr_output_shape;
863       auto output_shape = attr_output_shape.mutable_list()->add_shape();
864       if (NHWCToNCHW) {
865         output_shape->add_dim()->set_size(input_shape.dim(0).size());
866         output_shape->add_dim()->set_size(input_shape.dim(3).size());
867         output_shape->add_dim()->set_size(input_shape.dim(1).size());
868         output_shape->add_dim()->set_size(input_shape.dim(2).size());
869       } else {
870         output_shape->add_dim()->set_size(input_shape.dim(0).size());
871         output_shape->add_dim()->set_size(input_shape.dim(2).size());
872         output_shape->add_dim()->set_size(input_shape.dim(3).size());
873         output_shape->add_dim()->set_size(input_shape.dim(1).size());
874       }
875       node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
876     }
877     return node;
878   }
879 
880   NodeDef* AddNodePermNHWCToNCHW(const string& base_name,
881                                  const string& depended_node,
882                                  const string& device) {
883     string name =
884         LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNHWCToNCHW));
885     auto const_node = AddNodePermConst(name, device, {0, 3, 1, 2});
886     // This is to ensure the transpose node and the const node are in the
887     // same frame.
888     *const_node->add_input() = AsControlDependency(depended_node);
889     return const_node;
890   }
891 
892   NodeDef* AddNodePermNCHWToNHWC(const string& base_name,
893                                  const string& depended_node,
894                                  const string& device) {
895     auto const_node = AddNodePermConst(
896         LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNCHWToNHWC)),
897         device, {0, 2, 3, 1});
898     // This is to ensure the transpose node and the const node are in the same
899     // frame.
900     *const_node->add_input() = AsControlDependency(depended_node);
901     return const_node;
902   }
903 
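  // Inside a loop frame, a per-node perm const is created (its control
  // dependency keeps it in the same frame); otherwise the shared perm const
  // added elsewhere by this pass is reused.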
904   string GetOrAddNodePermNHWCToNCHW(int pos) {
905     string const_name;
906     if (is_in_frame_) {
907       string base_name = strings::StrCat(node_->name(), "-", pos);
908       string input = NodeName(node_->input(pos));
909       string depended_node;
910       if (!IsTransposeNCHWToNHWC(input)) {
911         depended_node = input;
912       } else {
913         auto input_node = node_map_->GetNode(input);
914         depended_node = NodeName(input_node->input(0));
915       }
916       auto const_node =
917           AddNodePermNHWCToNCHW(base_name, depended_node, node_->device());
918       const_name = const_node->name();
919     } else {
920       const_name = LayoutOptimizerNode(kPermNHWCToNCHW);
921     }
922     return const_name;
923   }
924 
925   string GetOrAddNodePermNCHWToNHWC() {
926     string const_name;
927     if (is_in_frame_) {
928       auto const_node =
929           AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
930       const_name = const_node->name();
931     } else {
932       const_name = LayoutOptimizerNode(kPermNCHWToNHWC);
933     }
934     return const_name;
935   }
936 
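  // Reorders a length-4 attr list such as ksize or strides from NHWC to NCHW.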
937   void UpdateTuple(AttrValue_ListValue* list) {
938     int64 h = list->i(1);
939     int64 w = list->i(2);
940     int64 c = list->i(3);
941     list->set_i(1, c);
942     list->set_i(2, h);
943     list->set_i(3, w);
944   }
945 
946   bool IsInputOnHost(const string& input_name) const {
947     string device = node_->device();
948     DeviceNameUtils::ParsedName parsed_name;
949     if (DeviceNameUtils::ParseFullName(device, &parsed_name)) {
950       if (parsed_name.type != "CPU") {
951         NodeDef* input = node_map_->GetNode(input_name);
952         int port;
953         ParseNodeName(input_name, &port);
954         if (IsHostMemory(*input, port)) {
955           return true;
956         }
957       }
958     }
959     return false;
960   }
961 
962   NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
963                                const string& op, DataType dtype,
964                                bool nhwc_to_nchw) {
965     NodeDef* added_node = graph_->add_node();
966     added_node->set_name(name);
967     added_node->set_op(op);
968     node_map_->AddNode(added_node->name(), added_node);
969     added_node->set_device(node_->device());
970     // The inputs of a DataFormat op could be in host memory for ops such as
971     // Reshape. In such cases, run the kernel on the host too.
972     if (IsInputOnHost(input_name)) {
973       AttrValue attr_kernel;
974       attr_kernel.set_s("host");
975       added_node->mutable_attr()->insert({"_kernel", attr_kernel});
976     }
977     AttrValue attr_data_type;
978     attr_data_type.set_type(dtype);
979     added_node->mutable_attr()->insert({"T", attr_data_type});
980     string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
981     string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
982     AttrValue attr_format;
983     attr_format.set_s(src_format);
984     added_node->mutable_attr()->insert({"src_format", attr_format});
985     attr_format.set_s(dst_format);
986     added_node->mutable_attr()->insert({"dst_format", attr_format});
987     *added_node->add_input() = input_name;
988     return added_node;
989   }
990 
991   void AddDataFormatTranformToParamInput(const string& op, int input_pos,
992                                          DataType dtype) {
993     string suffix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
994                                                    : kDimMapNHWCToNCHW;
995     string name = LayoutOptimizerNode(
996         strings::StrCat(node_->name(), "-", input_pos, "-", suffix));
997     auto added_node =
998         AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
999     *node_->mutable_input(input_pos) = added_node->name();
1000     node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
1001                             added_node->name());
1002     node_map_->AddOutput(added_node->name(), node_->name());
1003   }
1004 };
1005 
1006 class AvgPoolGradProcessor : public NodeProcessor {
1007  public:
1008   explicit AvgPoolGradProcessor(const OptimizeContext& opt_cxt)
1009       : NodeProcessor(opt_cxt) {}
1010 
1011  protected:
1012   std::vector<int> GetInputPos() const override { return {1}; }
1013   Status CustomizedProcessing() override {
1014     return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
1015   }
1016 };
1017 
1018 class BiasAddGradProcessor : public NodeProcessor {
1019  public:
1020   explicit BiasAddGradProcessor(const OptimizeContext& opt_cxt)
1021       : NodeProcessor(opt_cxt) {}
1022 
1023  protected:
1024   bool ShouldProcess() const override {
1025     if (MustPreserve()) {
1026       return false;
1027     }
1028     if (!IsOnGPU()) {
1029       return false;
1030     }
1031     auto input = node_map_->GetNode(node_->input(0));
1032     if (input) {
1033       int port;
1034       ParseNodeName(node_->input(0), &port);
1035       if (IsNHWC() && IsPortDimsFour(*input, port)) {
1036         return true;
1037       }
1038     }
1039     return false;
1040   }
1041 
1042   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1043 };
1044 
1045 class Conv2DProcessor : public NodeProcessor {
1046  public:
1047   Conv2DProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
1048       : NodeProcessor(opt_cxt), no_gemm_(no_gemm) {}
1049 
1050  protected:
1051   bool ShouldProcess() const override {
1052     return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
1053            HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU();
1054   }
1055 
1056   TensorShapeProto GetShape(const string& input_name) const {
1057     string node_name;
1058     int output_pos;
1059     node_name = ParseNodeName(input_name, &output_pos);
1060     NodeDef* node = node_map_->GetNode(node_name);
1061     if (node->attr().find("_output_shapes") != node->attr().end()) {
1062       return node->attr().at("_output_shapes").list().shape(output_pos);
1063     }
1064     TensorShapeProto shape;
1065     return shape;
1066   }
1067 
1068   bool IsStrideOne() const {
1069     if (node_->attr().find("strides") != node_->attr().end()) {
1070       auto list = node_->attr().at("strides").list();
1071       return list.i(1) == 1 && list.i(2) == 1;
1072     }
1073     return false;
1074   }
1075 
1076   bool IsValidPadding() const {
1077     if (node_->attr().find("padding") != node_->attr().end()) {
1078       auto padding = node_->attr().at("padding").s();
1079       return padding == "VALID";
1080     }
1081     return false;
1082   }
1083 
1084   // The logic inside this function is based on the internal implementation of
1085   // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
1086   // needs to be updated accordingly if the internal implementation changes.
1087   bool IsGemmUsed(const TensorShapeProto& filter_shape,
1088                   const TensorShapeProto& input_shape) const {
1089     if (filter_shape.dim_size() == 4) {
1090       if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 &&
1091           IsStrideOne()) {
1092         return true;
1093       }
1094     }
1095     if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) {
1096       if (input_shape.dim(1).size() == filter_shape.dim(0).size() &&
1097           input_shape.dim(2).size() == filter_shape.dim(1).size() &&
1098           IsValidPadding()) {
1099         return true;
1100       }
1101     }
1102     return false;
1103   }
1104 
1105   virtual bool IsGemmUsed() const {
1106     auto filter_shape = GetShape(node_->input(1));
1107     auto input_shape = GetShape(node_->input(0));
1108     return IsGemmUsed(filter_shape, input_shape);
1109   }
1110 
1111   bool no_gemm_;
1112 };
1113 
1114 class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
1115  public:
1116   Conv2DBackpropFilterProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
1117       : Conv2DProcessor(opt_cxt, no_gemm) {}
1118 
1119  protected:
1120   bool IsGemmUsed() const override {
1121     auto filter_shape = GetShape(node_->name());
1122     auto input_shape = GetShape(node_->input(0));
1123     return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
1124   }
1125 
1126   std::vector<int> GetInputPos() const override { return {0, 2}; }
1127 
1128   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1129   // No need to update output shape, as it is always of shape
1130   // [filter_height, filter_width, in_channels, out_channels], regardless of
1131   // whether NCHW or NHWC is used.
1132   void UpdateAttrShape() override {}
1133 };
1134 
1135 class Conv2DBackpropInputProcessor : public Conv2DProcessor {
1136  public:
1137   Conv2DBackpropInputProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
1138       : Conv2DProcessor(opt_cxt, no_gemm) {}
1139 
1140  protected:
1141   bool IsGemmUsed() const override {
1142     auto filter_shape = GetShape(node_->input(1));
1143     auto input_shape = GetShape(node_->name());
1144     return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
1145   }
1146 
1147   std::vector<int> GetInputPos() const override { return {2}; }
1148 
1149   Status CustomizedProcessing() override {
1150     return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
1151   }
1152 };
1153 
1154 class FusedBatchNormGradProcessor : public NodeProcessor {
1155  public:
1156   explicit FusedBatchNormGradProcessor(const OptimizeContext& opt_cxt)
1157       : NodeProcessor(opt_cxt) {}
1158 
1159  protected:
1160   bool ShouldProcess() const override {
1161     return NodeProcessor::ShouldProcess() && IsTraining();
1162   }
1163 
1164   std::vector<int> GetInputPos() const override { return {0, 1}; }
1165 
1166  private:
1167   bool IsTraining() const {
1168     if (node_->attr().find("is_training") != node_->attr().end()) {
1169       if (node_->attr().at("is_training").b()) {
1170         return true;
1171       }
1172     }
1173     return false;
1174   }
1175 };
1176 
1177 class MaxPoolGradProcessor : public NodeProcessor {
1178  public:
1179   explicit MaxPoolGradProcessor(const OptimizeContext& opt_cxt)
1180       : NodeProcessor(opt_cxt) {}
1181 
1182  protected:
1183   std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
1184 };
1185 
1186 class MaxPoolGradV2Processor : public MaxPoolGradProcessor {
1187  public:
1188   explicit MaxPoolGradV2Processor(const OptimizeContext& opt_cxt)
1189       : MaxPoolGradProcessor(opt_cxt) {}
1190 
1191  protected:
1192   Status CustomizedProcessing() override {
1193     for (int i = 3; i <= 4; i++) {
1194       TF_RETURN_IF_ERROR(
1195           UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
1196     }
1197     return Status::OK();
1198   }
1199 };
1200 
1201 class MaxPoolV2Processor : public NodeProcessor {
1202  public:
1203   explicit MaxPoolV2Processor(const OptimizeContext& opt_cxt)
1204       : NodeProcessor(opt_cxt) {}
1205 
1206  protected:
1207   bool ShouldProcess() const override {
1208     // We check data_input's shape instead, because the shape inference of
1209     // MaxPoolV2 is not able to infer the shape when ksize or strides is not
1210     // constant.
1211     auto data_input = node_map_->GetNode(node_->input(0));
1212     int port;
1213     ParseNodeName(node_->input(0), &port);
1214     return !MustPreserve() && IsNHWC() && IsPortDimsFour(*data_input, port) &&
1215            HasOutputs() && IsOnGPU();
1216   }
1217 
1218   Status CustomizedProcessing() override {
1219     for (int i = 1; i <= 2; i++) {
1220       TF_RETURN_IF_ERROR(
1221           UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
1222     }
1223     return Status::OK();
1224   }
1225 };
1226 
1227 class AgnosticNodeProcessor : public NodeProcessor {
1228  public:
1229   explicit AgnosticNodeProcessor(const OptimizeContext& opt_cxt)
1230       : NodeProcessor(opt_cxt) {}
1231 
1232  protected:
1233   bool ShouldProcess() const override {
1234     return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
1235            IsNodeAfterNCHWToNHWC() && IsOnGPU();
1236   }
1237 
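  // Breadth-first search backwards through format-agnostic ops: returns true
  // if this node (transitively) consumes the output of an NCHW-to-NHWC
  // transform inserted by this pass.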
1238   bool IsNodeAfterNCHWToNHWC(const NodeDef& node) const {
1239     std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
1240     std::deque<NodeDef*> queue;
1241     auto data_node_pos = DataInputPos(node);
1242     std::unordered_set<string> visited;
1243     for (const auto& pos : data_node_pos) {
1244       auto input_node = node_map_->GetNode(node.input(pos));
1245       queue.push_back(input_node);
1246       visited.insert(input_node->name());
1247     }
1248     // The code will exit this while loop in one iteration in most cases, as the
1249     // graph is already topologically sorted.
1250     while (!queue.empty()) {
1251       NodeDef* current_node = queue.front();
1252       queue.pop_front();
1253       if (IsTransposeNCHWToNHWC(current_node->name()) ||
1254           IsDimMapNCHWToNHWC(current_node->name()) ||
1255           IsVecPermuteNCHWToNHWC(current_node->name())) {
1256         return true;
1257       }
1258       // We only continue searching if the path is connected through
1259       // format-agnostic nodes.
1260       if (ops_format_agnostic.find(current_node->op()) !=
1261           ops_format_agnostic.end()) {
1262         auto current_node_pos = DataInputPos(*current_node);
1263         for (const auto& pos : current_node_pos) {
1264           auto input_node = node_map_->GetNode(current_node->input(pos));
1265           if (visited.find(input_node->name()) == visited.end()) {
1266             queue.push_back(input_node);
1267             visited.insert(input_node->name());
1268           }
1269         }
1270       }
1271     }
1272     return false;
1273   }
1274 
1275   bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
1276 };
1277 
1278 class AddNProcessor : public AgnosticNodeProcessor {
1279  public:
1280   explicit AddNProcessor(const OptimizeContext& opt_cxt)
1281       : AgnosticNodeProcessor(opt_cxt) {}
1282 
1283  protected:
1284   std::vector<int> GetInputPos() const override {
1285     return NonControlInputs(*node_);
1286   }
1287 };
1288 
1289 class BinaryOpProcessor : public AgnosticNodeProcessor {
1290  public:
1291   explicit BinaryOpProcessor(const OptimizeContext& opt_cxt)
1292       : AgnosticNodeProcessor(opt_cxt) {}
1293 
1294  protected:
1295   bool ShouldProcess() const override {
1296     return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
1297            IsNodeAfterNCHWToNHWC() &&
1298            (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) ||
1299             IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4) ||
1300             IsNDOperateWithMD(1, 4)) &&
1301            IsOnGPU();
1302   }
1303 
1304   std::vector<int> GetInputPos() const override {
1305     std::vector<int> input_pos;
1306     auto input0 = node_map_->GetNode(node_->input(0));
1307     auto input1 = node_map_->GetNode(node_->input(1));
1308     int input0_port;
1309     ParseNodeName(node_->input(0), &input0_port);
1310     int input1_port;
1311     ParseNodeName(node_->input(1), &input1_port);
1312     if (IsPortDimsFour(*input0, input0_port)) {
1313       input_pos.push_back(0);
1314     }
1315     if (IsPortDimsFour(*input1, input1_port)) {
1316       input_pos.push_back(1);
1317     }
1318     return input_pos;
1319   }
1320 
1321   bool IsNDOperateWithMD(int n, int m) const {
1322     auto input0 = node_map_->GetNode(node_->input(0));
1323     auto input1 = node_map_->GetNode(node_->input(1));
1324     int input0_port;
1325     ParseNodeName(node_->input(0), &input0_port);
1326     int input1_port;
1327     ParseNodeName(node_->input(1), &input1_port);
1328 
1329     if (input0 && input1) {
1330       bool input0_is_n = (n == 4) ? IsPortDimsFour(*input0, input0_port)
1331                                   : IsPortDimsN(*input0, input0_port, n);
1332       bool input1_is_m = (m == 4) ? IsPortDimsFour(*input1, input1_port)
1333                                   : IsPortDimsN(*input1, input1_port, m);
1334       return input0_is_n && input1_is_m;
1335     }
1336     return false;
1337   }
1338 
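  // Adds a Const node holding the shape {1, num_channels, 1, 1}. A length-C
  // vector operand reshaped to this shape broadcasts along the channel
  // dimension of an NCHW tensor, which is what CustomizedProcessing() below
  // relies on.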
1339   NodeDef* AddNodeShapeConst(const string& name, int num_channels,
1340                              const string& depended_node) {
1341     NodeDef* node = graph_->add_node();
1342     node_map_->AddNode(name, node);
1343     node->set_name(name);
1344     node->set_op("Const");
1345     node->set_device(node_->device());
1346     AttrValue attr_data_type;
1347     attr_data_type.set_type(DT_INT32);
1348     node->mutable_attr()->insert({"dtype", attr_data_type});
1349 
1350     AttrValue attr_tensor;
1351     Tensor tensor(DT_INT32, TensorShape({4}));
1352     std::vector<int> shape = {1, num_channels, 1, 1};
1353     for (int i = 0; i < static_cast<int>(shape.size()); i++) {
1354       tensor.flat<int>()(i) = shape[i];
1355     }
1356     tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
1357     node->mutable_attr()->insert({"value", attr_tensor});
1358     if (is_in_frame_) {
1359       // This is to ensure the const node and the Reshape node that consumes
1360       // it are in the same frame.
1361       *node->add_input() = AsControlDependency(depended_node);
1362     }
1363     return node;
1364   }
1365 
1366   NodeDef* AddNodeReshape(const string& node_name, const string& input_name,
1367                           const string& shape_const_node_name,
1368                           DataType data_type) {
1369     NodeDef* node = graph_->add_node();
1370     node_map_->AddNode(node_name, node);
1371     node->set_name(node_name);
1372     *node->add_input() = input_name;
1373     *node->add_input() = shape_const_node_name;
1374     node->set_op("Reshape");
1375     node->set_device(node_->device());
1376 
1377     AttrValue attr_type_indices;
1378     attr_type_indices.set_type(DT_INT32);
1379     node->mutable_attr()->insert({"Tshape", attr_type_indices});
1380 
1381     AttrValue attr_type_params;
1382     attr_type_params.set_type(data_type);
1383     node->mutable_attr()->insert({"T", attr_type_params});
1384     return node;
1385   }
1386 
1387   Status CustomizedProcessing() override {
1388     int vector_index = -1;
1389     if (IsNDOperateWithMD(4, 1)) {
1390       vector_index = 1;
1391     } else if (IsNDOperateWithMD(1, 4)) {
1392       vector_index = 0;
1393     }
1394     if (vector_index != -1) {
1395       string base_name = strings::StrCat(node_->name(), "-", vector_index);
1396       string reshape_node_name = LayoutOptimizerNode(
1397           strings::StrCat(base_name, "-", kReshapeNHWCToNCHW));
1398       string shape_const_node_name =
1399           LayoutOptimizerNode(strings::StrCat(base_name, "-", kReshapeConst));
1400       auto input_node = node_map_->GetNode(node_->input(vector_index));
1401       TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
1402       int port;
1403       ParseNodeName(node_->input(vector_index), &port);
1404       int vector_size = input_node->attr()
1405                             .at("_output_shapes")
1406                             .list()
1407                             .shape(port)
1408                             .dim(0)
1409                             .size();
1410       AddNodeShapeConst(shape_const_node_name, vector_size,
1411                         NodeName(node_->input(vector_index)));
1412       TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
1413       AddNodeReshape(reshape_node_name, node_->input(vector_index),
1414                      shape_const_node_name, node_->attr().at("T").type());
1415       node_map_->AddOutput(shape_const_node_name, reshape_node_name);
1416       node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
1417                               node_->name(), reshape_node_name);
1418       node_map_->AddOutput(reshape_node_name, node_->name());
1419       *node_->mutable_input(vector_index) = reshape_node_name;
1420     }
1421     return Status::OK();
1422   }
1423 };
1424 
1425 class ConcatProcessor : public AgnosticNodeProcessor {
1426  public:
1427   explicit ConcatProcessor(const OptimizeContext& opt_cxt)
1428       : AgnosticNodeProcessor(opt_cxt) {
1429     // For Concat, the concat axis is the first input; for ConcatV2, it is the
1430     // last input. Note that if the node has control inputs, the number of
1431     // inputs is larger than the integer attribute N.
1432     int n = node_->attr().at("N").i();
1433     axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
1434   }
1435 
1436  protected:
1437   std::vector<int> GetInputPos() const override {
1438     return DataInputPosConcat(*node_);
1439   }
1440 
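  // The concat axis is a scalar dimension index expressed in NHWC order, so it
  // is rewritten with DataFormatDimMap (e.g. a channel axis of 3 in NHWC
  // becomes 1 in NCHW).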
1441   Status CustomizedProcessing() override {
1442     DataType dtype =
1443         (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
1444     return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
1445                                        dtype);
1446   }
1447 
1448   int axis_node_pos_;
1449 };
1450 
1451 class FillProcessor : public AgnosticNodeProcessor {
1452  public:
1453   explicit FillProcessor(const OptimizeContext& opt_cxt)
1454       : AgnosticNodeProcessor(opt_cxt) {}
1455 
1456  protected:
1457   std::vector<int> GetInputPos() const override { return {}; }
1458 
1459   Status CustomizedProcessing() override {
1460     DataType dtype = node_->attr().at("index_type").type();
1461     return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
1462   }
1463 };
1464 
1465 class HistogramSummaryProcessor : public AgnosticNodeProcessor {
1466  public:
1467   explicit HistogramSummaryProcessor(const OptimizeContext& opt_cxt)
1468       : AgnosticNodeProcessor(opt_cxt) {}
1469 
1470  protected:
1471   bool ShouldProcess() const override {
1472     auto input1 = node_map_->GetNode(node_->input(1));
1473     int port;
1474     ParseNodeName(node_->input(1), &port);
1475     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
1476            IsPortDimsFour(*input1, port) && IsOnGPU();
1477   }
1478 
1479   std::vector<int> GetInputPos() const override { return {1}; }
1480 
1481   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1482 };
1483 
1484 class IdentityNProcessor : public AgnosticNodeProcessor {
1485  public:
1486   explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
1487       : AgnosticNodeProcessor(opt_cxt) {
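    // Collect the positions of non-control inputs that are 4-D and already lie
    // on an NCHW-converted path (either fed directly by an added NCHW-to-NHWC
    // transpose or reachable through format-agnostic ops); only those inputs
    // and the matching outputs are transposed.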
1488     std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
1489     for (int i = 0; i < node_->input_size(); i++) {
1490       auto input = node_map_->GetNode(node_->input(i));
1491       int port;
1492       ParseNodeName(node_->input(i), &port);
1493       // Skip control input.
1494       if (port != -1) {
1495         bool is_agnostic =
1496             ops_format_agnostic.find(input->op()) != ops_format_agnostic.end();
1497         if (IsPortDimsFour(*input, port) &&
1498             ((IsNodeAfterNCHWToNHWC(*input) && is_agnostic) ||
1499              IsTransposeNCHWToNHWC(input->name()))) {
1500           input_pos_.push_back(i);
1501         }
1502       }
1503     }
1504   }
1505 
1506  protected:
1507   bool ShouldProcess() const override {
1508     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
1509            IsOnGPU();
1510   }
1511 
1512   std::vector<int> GetInputPos() const override { return input_pos_; }
1513 
1514   std::set<int> GetOutputPos() const override {
1515     std::set<int> output_pos{};
1516     for (const auto& input_pos : input_pos_) {
1517       output_pos.insert(input_pos);
1518     }
1519     return output_pos;
1520   }
1521 
1522  private:
1523   std::vector<int> input_pos_;
1524 };
1525 
1526 class ShapeProcessor : public IdentityNProcessor {
1527  public:
1528   explicit ShapeProcessor(const OptimizeContext& opt_cxt)
1529       : IdentityNProcessor(opt_cxt) {}
1530 
1531  protected:
1532   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1533 
1534   Status CustomizedProcessing() override {
1535     return AddTransformToOutputs("DataFormatVecPermute");
1536   }
1537 };
1538 
1539 class MergeProcessor : public AgnosticNodeProcessor {
1540  public:
1541   explicit MergeProcessor(const OptimizeContext& opt_cxt)
1542       : AgnosticNodeProcessor(opt_cxt) {}
1543 
1544  protected:
1545   bool ShouldProcess() const override {
1546     return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
1547            IsEveryInputAfterNCHWToNHWC() && IsOnGPU();
1548   }
1549 
1550   std::vector<int> GetInputPos() const override {
1551     std::vector<int> input_pos;
1552     int n = node_->attr().at("N").i();
1553     input_pos.reserve(n);
1554     for (int i = 0; i < n; i++) {
1555       input_pos.push_back(i);
1556     }
1557     return input_pos;
1558   }
1559 
1560  private:
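  // Merge is only converted when every data input is 4-D and is either an
  // added NCHW-to-NHWC transform or a format-agnostic node downstream of one;
  // otherwise the node is left in NHWC.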
1561   bool IsEveryInputAfterNCHWToNHWC() const {
1562     std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
1563     for (const auto& input : node_->input()) {
1564       auto input_node = node_map_->GetNode(input);
1565       int port;
1566       ParseNodeName(input, &port);
1567       bool is_agnostic = ops_format_agnostic.find(input_node->op()) !=
1568                          ops_format_agnostic.end();
1569       if (IsPortDimsFour(*input_node, port) &&
1570           ((IsNodeAfterNCHWToNHWC(*input_node) && is_agnostic) ||
1571            IsTransposeNCHWToNHWC(input_node->name()))) {
1572         continue;
1573       }
1574       return false;
1575     }
1576     return true;
1577   }
1578 };
1579 
1580 class PadProcessor : public AgnosticNodeProcessor {
1581  public:
1582   explicit PadProcessor(const OptimizeContext& opt_cxt)
1583       : AgnosticNodeProcessor(opt_cxt) {}
1584 
1585  protected:
1586   Status CustomizedProcessing() override {
1587     DataType dtype = node_->attr().at("Tpaddings").type();
1588     return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
1589   }
1590 };
1591 
1592 class ReverseProcessor : public AgnosticNodeProcessor {
1593  public:
1594   explicit ReverseProcessor(const OptimizeContext& opt_cxt)
1595       : AgnosticNodeProcessor(opt_cxt) {}
1596 
1597  protected:
1598   Status CustomizedProcessing() override {
1599     DataType dtype = node_->attr().at("Tidx").type();
1600     return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
1601   }
1602 };
1603 
1604 class SplitProcessor : public AgnosticNodeProcessor {
1605  public:
1606   explicit SplitProcessor(const OptimizeContext& opt_cxt)
1607       : AgnosticNodeProcessor(opt_cxt) {
1608     axis_node_pos_ = 0;
1609   }
1610 
1611  protected:
1612   std::vector<int> GetInputPos() const override { return {1}; }
1613 
1614   std::set<int> GetOutputPos() const override {
1615     std::set<int> output_pos{0};
1616     if (HasAttribute(*node_, "num_split").ok()) {
1617       for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
1618         output_pos.insert(i);
1619       }
1620     }
1621     return output_pos;
1622   }
1623 
1624   Status CustomizedProcessing() override {
1625     return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
1626                                        DT_INT32);
1627   }
1628 
1629   int axis_node_pos_;
1630 };
1631 
1632 class SplitVProcessor : public SplitProcessor {
1633  public:
1634   explicit SplitVProcessor(const OptimizeContext& opt_cxt)
1635       : SplitProcessor(opt_cxt) {
1636     axis_node_pos_ = 2;
1637   }
1638 
1639  protected:
1640   std::vector<int> GetInputPos() const override { return {0}; }
1641 };
1642 
1643 class TernaryOpProcessor : public AgnosticNodeProcessor {
1644  public:
1645   explicit TernaryOpProcessor(const OptimizeContext& opt_cxt)
1646       : AgnosticNodeProcessor(opt_cxt) {}
1647 
1648  protected:
1649   std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
1650 };
1651 
1652 class SelectProcessor : public AgnosticNodeProcessor {
1653  public:
1654   explicit SelectProcessor(const OptimizeContext& opt_cxt)
1655       : AgnosticNodeProcessor(opt_cxt) {}
1656 
1657  protected:
1658   bool ShouldProcess() const override {
1659     auto input0 = node_map_->GetNode(node_->input(0));
1660     int input0_port;
1661     ParseNodeName(node_->input(0), &input0_port);
1662     bool is_input0_scalar_vector_4d = IsPortDimsN(*input0, input0_port, 0) ||
1663                                       IsPortDimsN(*input0, input0_port, 1) ||
1664                                       IsPortDimsN(*input0, input0_port, 4);
1665     return AgnosticNodeProcessor::ShouldProcess() && is_input0_scalar_vector_4d;
1666   }
1667 
1668   std::vector<int> GetInputPos() const override {
1669     auto input0 = node_map_->GetNode(node_->input(0));
1670     int input0_port;
1671     ParseNodeName(node_->input(0), &input0_port);
1672     // Input 0 can be a scalar, a vector whose size matches the first dimension
1673     // of inputs 1 and 2, or a tensor with the same shape as inputs 1 and 2.
1674     if (IsPortDimsFour(*input0, input0_port)) {
1675       return {0, 1, 2};
1676     } else {
1677       return {1, 2};
1678     }
1679   }
1680 };
1681 
1682 class UnaryGradProcessor : public AgnosticNodeProcessor {
1683  public:
1684   explicit UnaryGradProcessor(const OptimizeContext& opt_cxt)
1685       : AgnosticNodeProcessor(opt_cxt) {}
1686 
1687  protected:
1688   std::vector<int> GetInputPos() const override { return {0, 1}; }
1689 };
1690 
1691 class SliceProcessor : public AgnosticNodeProcessor {
1692  public:
1693   explicit SliceProcessor(const OptimizeContext& opt_cxt)
1694       : AgnosticNodeProcessor(opt_cxt) {
1695     // Skip the first input, which is the data to be sliced.
1696     start_ = 1;
1697     // Note that we can't use node_->input_size() here because there
1698     // could be control inputs.
1699     end_ = 2;
1700   }
1701 
1702  protected:
1703   Status ProcessInputs() {
1704     for (int i = start_; i <= end_; i++) {
1705       DataType dtype = node_->attr().at("Index").type();
1706       TF_RETURN_IF_ERROR(
1707           UpdateOrTransformParamInput(i, "DataFormatVecPermute", dtype));
1708     }
1709     return Status::OK();
1710   }
1711 
1712   Status CustomizedProcessing() override { return ProcessInputs(); }
1713 
1714   int start_;
1715   int end_;
1716 };
1717 
1718 class StridedSliceProcessor : public SliceProcessor {
1719  public:
1720   explicit StridedSliceProcessor(const OptimizeContext& opt_cxt)
1721       : SliceProcessor(opt_cxt) {
1722     start_ = 1;
1723     end_ = 3;
1724   }
1725 
1726  protected:
1727   bool ShouldProcess() const override {
1728     return AgnosticNodeProcessor::ShouldProcess() && IsOnlyBeginEndMask();
1729   }
1730 
1731   Status CustomizedProcessing() override {
1732     TF_RETURN_IF_ERROR(UpdateMask("begin_mask"));
1733     TF_RETURN_IF_ERROR(UpdateMask("end_mask"));
1734     TF_RETURN_IF_ERROR(ProcessInputs());
1735     return Status::OK();
1736   }
1737 
1738  private:
1739   bool IsMaskZero(const string& mask) const {
1740     return node_->attr().at(mask).i() == 0;
1741   }
1742 
1743   bool IsOnlyBeginEndMask() const {
1744     return IsMaskZero("ellipsis_mask") && IsMaskZero("new_axis_mask") &&
1745            IsMaskZero("shrink_axis_mask");
1746   }
1747 
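  // begin_mask/end_mask are bitmasks over dimensions in NHWC order
  // (bit 0 = N, 1 = H, 2 = W, 3 = C). Converting to NCHW moves the H and W
  // bits up by one position and the C bit down to position 1, so e.g. mask 2
  // (H only) becomes 4 and mask 8 (C only) becomes 2. Masks 0, 1, 14 and 15
  // are unchanged because their H, W and C bits are all equal.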
1748   Status UpdateMask(const string& mask) {
1749     int i = node_->attr().at(mask).i();
1750     if (i < 0 || i > 15) {
1751       return errors::InvalidArgument("invalid mask value: ", i);
1752     }
1753     if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK();
1754     switch (i) {
1755       case 2:
1756       case 3:
1757         i += 2;
1758         break;
1759       case 4:
1760       case 5:
1761         i += 4;
1762         break;
1763       case 6:
1764       case 7:
1765         i += 6;
1766         break;
1767       case 8:
1768       case 9:
1769         i -= 6;
1770         break;
1771       case 10:
1772       case 11:
1773         i -= 4;
1774         break;
1775       case 12:
1776       case 13:
1777         i -= 2;
1778         break;
1779     }
1780     node_->mutable_attr()->at(mask).set_i(i);
1781     return Status::OK();
1782   }
1783 };
1784 
1785 class StridedSliceGradProcessor : public StridedSliceProcessor {
1786  public:
1787   explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt)
1788       : StridedSliceProcessor(opt_cxt) {
1789     start_ = 0;
1790     end_ = 3;
1791   }
1792 
1793  protected:
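  // For StridedSliceGrad, input 4 (the incoming gradient) is the only 4-D data
  // input that needs a layout transpose; inputs 0-3 (shape, begin, end,
  // strides) are 1-D parameters handled by ProcessInputs() with
  // DataFormatVecPermute.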
1794   std::vector<int> GetInputPos() const override { return {4}; }
1795 };
1796 
1797 class SqueezeProcessor : public AgnosticNodeProcessor {
1798  public:
1799   explicit SqueezeProcessor(const OptimizeContext& opt_cxt)
1800       : AgnosticNodeProcessor(opt_cxt) {}
1801 
1802  protected:
1803   bool ShouldProcess() const override {
1804     bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
1805                              (IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
1806     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
1807            IsInputConvertible() && is_dims_supported && IsOnGPU();
1808   }
1809 
1810   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
1811 
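  // After the input is converted to NCHW, the squeezed spatial dimensions
  // move: squeeze_dims {1, 2} (H, W in NHWC) become {2, 3}, and {0, 1, 2}
  // (N, H, W) become {0, 2, 3}.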
1812   Status CustomizedProcessing() override {
1813     TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
1814     auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
1815     if (list->i_size() == 2) {
1816       list->set_i(0, 2);
1817       list->set_i(1, 3);
1818     } else if (list->i_size() == 3) {
1819       list->set_i(1, 2);
1820       list->set_i(2, 3);
1821     }
1822     return Status::OK();
1823   }
1824 
1825  private:
1826   bool IsInputConvertible() const {
1827     int input_port;
1828     auto input = node_map_->GetNode(node_->input(0));
1829     ParseNodeName(node_->input(0), &input_port);
1830     if (input->attr().find("_output_shapes") != input->attr().end()) {
1831       auto shape = input->attr().at("_output_shapes").list().shape(input_port);
1832       if (shape.dim_size() != 4) {
1833         return false;
1834       }
1835       if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
1836         return true;
1837       }
1838       if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
1839           shape.dim(2).size() == 1) {
1840         return true;
1841       }
1842     }
1843     return false;
1844   }
1845 
1846   bool IsAlongAxis(const std::vector<int>& axis) const {
1847     if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
1848       auto list = node_->attr().at("squeeze_dims").list();
1849       // If list is empty, Squeeze op will squeeze all dimensions of size 1.
1850       if (list.i_size() == 0) return true;
1851       if (list.i_size() == axis.size()) {
1852         bool along_axis = true;
1853         for (int i = 0; i < axis.size(); i++) {
1854           along_axis = along_axis && (list.i(i) == axis[i]);
1855         }
1856         if (along_axis) return true;
1857       }
1858     }
1859     return false;
1860   }
1861   bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
1862   bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
1863 };
1864 
1865 class ReduceProcessor : public AgnosticNodeProcessor {
1866  public:
1867   explicit ReduceProcessor(const OptimizeContext& opt_cxt)
1868       : AgnosticNodeProcessor(opt_cxt) {}
1869 
1870  protected:
1871   bool ShouldProcess() const override {
1872     auto input0 = node_map_->GetNode(node_->input(0));
1873     int port;
1874     ParseNodeName(node_->input(0), &port);
1875     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
1876            IsPortDimsFour(*input0, port) && IsReduceAxisSupported() &&
1877            IsOnGPU();
1878   }
1879 
1880   Status CustomizedProcessing() override {
1881     if (IsReduceAxisSupported()) {
1882       DataType dtype = node_->attr().at("Tidx").type();
1883       TF_RETURN_IF_ERROR(
1884           UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
1885     }
1886     return Status::OK();
1887   }
1888 
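  // When keep_dims is true the reduction output is still 4-D (now in NCHW), so
  // a Transpose back to NHWC has to be appended to the outputs; without
  // keep_dims the supported axis combinations produce outputs whose shape is
  // the same regardless of the input layout, so no conversion is needed.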
1889   Status AddLayoutTransposeToOutputs() override {
1890     if (KeepDims()) {
1891       return AddTransformToOutputs("Transpose");
1892     }
1893     return Status::OK();
1894   }
1895 
1896  private:
1897   bool IsReduceAxisSupported() const {
1898     return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() ||
1899                            IsAlongNHW() || IsAlongHW() || IsAlongC()) &&
1900                           !KeepDims());
1901   }
1902 
1903   bool IsAlongAxis(const std::vector<int>& axis) const {
1904     auto axis_node = node_map_->GetNode(node_->input(1));
1905     if (!IsConstant(*axis_node)) {
1906       return false;
1907     }
1908     if (HasAttribute(*axis_node, "value").ok()) {
1909       Tensor tensor;
1910       auto success = tensor.FromProto(axis_node->attr().at({"value"}).tensor());
1911       if (!success) {
1912         LOG(ERROR) << "Failed to parse TensorProto.";
1913       }
1914       if (tensor.dims() == 1 && tensor.dim_size(0) == axis.size()) {
1915         bool along_axis = true;
1916         for (int i = 0; i < axis.size(); i++) {
1917           along_axis = along_axis && (tensor.flat<int>()(i) == axis[i]);
1918         }
1919         if (along_axis) return true;
1920       }
1921     }
1922     return false;
1923   }
1924 
1925   bool IsAlongAllFourDims() const { return IsAlongAxis({0, 1, 2, 3}); }
1926 
1927   bool IsAlongHWC() const { return IsAlongAxis({1, 2, 3}); }
1928 
1929   bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
1930 
1931   bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
1932 
1933   bool IsAlongC() const { return IsAlongAxis({3}); }
1934 
1935   bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
1936 };
1937 
1938 class SwitchProcessor : public AgnosticNodeProcessor {
1939  public:
1940   explicit SwitchProcessor(const OptimizeContext& opt_cxt)
1941       : AgnosticNodeProcessor(opt_cxt) {}
1942 
1943  protected:
1944   std::set<int> GetOutputPos() const override { return {0, 1}; }
1945 };
1946 
1947 class TileProcessor : public AgnosticNodeProcessor {
1948  public:
1949   explicit TileProcessor(const OptimizeContext& opt_cxt)
1950       : AgnosticNodeProcessor(opt_cxt) {}
1951 
1952  protected:
1953   Status CustomizedProcessing() override {
1954     DataType dtype = node_->attr().at("Tmultiples").type();
1955     return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
1956   }
1957 };
1958 
1959 class DataLayoutOptimizer : GraphProcessor {
1960  public:
1961   explicit DataLayoutOptimizer(
1962       const GraphProperties& graph_properties,
1963       const VirtualPlacer& virtual_placer,
1964       const LayoutOptimizer::TuningConfig& config,
1965       const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
1966       NodeMap* node_map)
1967       : GraphProcessor(graph_properties, virtual_placer, nodes_to_preserve,
1968                        graph, node_map),
1969         config_(config) {}
1970 
1971   Status Optimize() {
1972     VLOG(1) << "Number of nodes for original graph: " << graph_->node_size();
1973     TF_RETURN_IF_ERROR(Expand());
1974     VLOG(1) << "Number of nodes after Expand: " << graph_->node_size();
1975     TF_RETURN_IF_ERROR(Collapse());
1976     VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size();
1977     return Status::OK();
1978   }
1979 
1980  private:
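  // Shared permutation constants used by the Transpose nodes added outside of
  // frames: applying perm {0, 3, 1, 2} to an NHWC tensor yields NCHW, and
  // {0, 2, 3, 1} is its inverse.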
1981   NodeDef* AddNodePermNHWCToNCHW() {
1982     return AddNodePermConst(LayoutOptimizerNode(kPermNHWCToNCHW), "",
1983                             {0, 3, 1, 2});
1984   }
1985 
1986   NodeDef* AddNodePermNCHWToNHWC() {
1987     return AddNodePermConst(LayoutOptimizerNode(kPermNCHWToNHWC), "",
1988                             {0, 2, 3, 1});
1989   }
1990 
1991   // Expand all nodes that are in NHWC but support NCHW or are layout agnostic.
1992   Status Expand() {
1993     int node_size_original = graph_->node_size();
1994 
1995     FrameView frame_view;
1996     TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph_));
1997 
1998     // This is the first pass where we expand the nodes which support NCHW.
1999     std::set<string> ops_format_supported = GetOpsFormatSupported();
2000     for (int i = 0; i < node_size_original; i++) {
2001       if (IsNodeByLayoutOptimizer(graph_->node(i).name())) {
2002         return Status(error::INVALID_ARGUMENT,
2003                       "The graph is already optimized by layout optimizer.");
2004       }
2005       if (ops_format_supported.find(graph_->node(i).op()) !=
2006           ops_format_supported.end()) {
2007         auto node = graph_->mutable_node(i);
2008         bool is_in_frame = frame_view.IsInFrame(*node);
2009         OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
2010                                 virtual_placer_, nodes_to_preserve_,
2011                                 is_in_frame);
2012         std::unique_ptr<NodeProcessor> node_processor;
2013         if (IsAvgPoolGrad(*node)) {
2014           node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
2015         } else if (IsBiasAddGrad(*node)) {
2016           node_processor.reset(new BiasAddGradProcessor(opt_cxt));
2017         } else if (IsConv2D(*node)) {
2018           node_processor.reset(new Conv2DProcessor(opt_cxt, config_.no_gemm));
2019         } else if (IsConv2DBackpropFilter(*node)) {
2020           node_processor.reset(
2021               new Conv2DBackpropFilterProcessor(opt_cxt, config_.no_gemm));
2022         } else if (IsConv2DBackpropInput(*node)) {
2023           node_processor.reset(
2024               new Conv2DBackpropInputProcessor(opt_cxt, config_.no_gemm));
2025         } else if (IsDepthwiseConv2dNative(*node)) {
2026           node_processor.reset(new Conv2DProcessor(opt_cxt, true));
2027         } else if (IsDepthwiseConv2dNativeBackpropFilter(*node)) {
2028           node_processor.reset(
2029               new Conv2DBackpropFilterProcessor(opt_cxt, true));
2030         } else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
2031           node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
2032         } else if (IsFusedBatchNormGrad(*node)) {
2033           node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
2034         } else if (IsMaxPoolV2(*node)) {
2035           node_processor.reset(new MaxPoolV2Processor(opt_cxt));
2036         } else if (IsMaxPoolGradV1(*node) || IsMaxPoolGradGradV1(*node)) {
2037           node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
2038         } else if (IsMaxPoolGradV2(*node) || IsMaxPoolGradGradV2(*node)) {
2039           node_processor.reset(new MaxPoolGradV2Processor(opt_cxt));
2040         } else {
2041           node_processor.reset(new NodeProcessor(opt_cxt));
2042         }
2043         TF_RETURN_IF_ERROR(node_processor->ConvertNode());
2044       }
2045     }
2046 
2047     // This is the second pass, where we expand layout-agnostic nodes. This pass
2048     // only needs to be performed if at least one node was expanded in the
2049     // previous pass.
2050     if (graph_->node_size() > node_size_original) {
2051       // Create Const nodes holding the permutation used by added Transposes of
2052       // nodes not in a frame.
2053       AddNodePermNHWCToNCHW();
2054       AddNodePermNCHWToNHWC();
2055       std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
2056       for (int i = 0; i < graph_->node_size(); i++) {
2057         if (ops_format_agnostic.find(graph_->node(i).op()) !=
2058             ops_format_agnostic.end()) {
2059           auto node = graph_->mutable_node(i);
2060           bool is_in_frame = frame_view.IsInFrame(*node);
2061           OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
2062                                   virtual_placer_, nodes_to_preserve_,
2063                                   is_in_frame);
2064           std::unique_ptr<NodeProcessor> node_processor;
2065           if (IsAddN(*node)) {
2066             node_processor.reset(new AddNProcessor(opt_cxt));
2067           } else if (IsBetainc(*node)) {
2068             node_processor.reset(new TernaryOpProcessor(opt_cxt));
2069           } else if (IsBinaryOp(*node)) {
2070             node_processor.reset(new BinaryOpProcessor(opt_cxt));
2071           } else if (IsConcat(*node)) {
2072             node_processor.reset(new ConcatProcessor(opt_cxt));
2073           } else if (IsFill(*node)) {
2074             node_processor.reset(new FillProcessor(opt_cxt));
2075           } else if (IsHistogramSummary(*node)) {
2076             node_processor.reset(new HistogramSummaryProcessor(opt_cxt));
2077           } else if (IsIdentityN(*node)) {
2078             node_processor.reset(new IdentityNProcessor(opt_cxt));
2079           } else if (IsMerge(*node)) {
2080             node_processor.reset(new MergeProcessor(opt_cxt));
2081           } else if (IsPad(*node) || IsMirrorPad(*node) ||
2082                      IsMirrorPadGrad(*node)) {
2083             node_processor.reset(new PadProcessor(opt_cxt));
2084           } else if (IsReduceOp(*node)) {
2085             node_processor.reset(new ReduceProcessor(opt_cxt));
2086           } else if (IsReverseV2(*node)) {
2087             node_processor.reset(new ReverseProcessor(opt_cxt));
2088           } else if (IsSelect(*node)) {
2089             node_processor.reset(new SelectProcessor(opt_cxt));
2090           } else if (IsSlice(*node)) {
2091             node_processor.reset(new SliceProcessor(opt_cxt));
2092           } else if (IsStridedSlice(*node)) {
2093             node_processor.reset(new StridedSliceProcessor(opt_cxt));
2094           } else if (IsShape(*node) || IsShapeN(*node)) {
2095             node_processor.reset(new ShapeProcessor(opt_cxt));
2096           } else if (IsSplit(*node)) {
2097             node_processor.reset(new SplitProcessor(opt_cxt));
2098           } else if (IsSplitV(*node)) {
2099             node_processor.reset(new SplitVProcessor(opt_cxt));
2100           } else if (IsSqueeze(*node)) {
2101             node_processor.reset(new SqueezeProcessor(opt_cxt));
2102           } else if (IsStridedSliceGrad(*node)) {
2103             node_processor.reset(new StridedSliceGradProcessor(opt_cxt));
2104           } else if (IsSwitch(*node)) {
2105             node_processor.reset(new SwitchProcessor(opt_cxt));
2106           } else if (IsTile(*node)) {
2107             node_processor.reset(new TileProcessor(opt_cxt));
2108           } else if (IsUnaryGrad(*node)) {
2109             node_processor.reset(new UnaryGradProcessor(opt_cxt));
2110           } else {
2111             node_processor.reset(new AgnosticNodeProcessor(opt_cxt));
2112           }
2113           TF_RETURN_IF_ERROR(node_processor->ConvertNode());
2114         }
2115       }
2116     }
2117     return Status::OK();
2118   }
2119 
2120   // Remove all node pairs where an NCHW-to-NHWC node is followed by an
2121   // NHWC-to-NCHW node.
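  // For example, X -> "...-TransposeNCHWToNHWC" -> "...-TransposeNHWCToNCHW" -> Y
  // collapses to X -> Y: the consumer is rewired to the input of the first
  // transform and both transform nodes are deleted below.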
2122   Status Collapse() {
2123     std::unordered_set<string> nodes_removable;
2124     for (int i = 0; i < graph_->node_size(); i++) {
2125       auto node = graph_->mutable_node(i);
2126       node->mutable_attr()->erase("_output_shapes");
2127       if (IsTransposeNHWCToNCHW(node->name()) ||
2128           IsDimMapNHWCToNCHW(node->name()) ||
2129           IsVecPermuteNHWCToNCHW(node->name())) {
2130         bool transpose_pair = IsTransposeNHWCToNCHW(node->name()) &&
2131                               IsTransposeNCHWToNHWC(node->input(0));
2132         bool dim_map_pair = IsDimMapNHWCToNCHW(node->name()) &&
2133                             IsDimMapNCHWToNHWC(node->input(0));
2134         bool vec_permute_pair = IsVecPermuteNHWCToNCHW(node->name()) &&
2135                                 IsVecPermuteNCHWToNHWC(node->input(0));
2136         if (transpose_pair || dim_map_pair || vec_permute_pair) {
2137           const string& trans_first = node->input(0);
2138           const string& trans_second = node->name();
2139           auto outputs = node_map_->GetOutputs(trans_second);
2140           CHECK(outputs.size() == 1)
2141               << "There is always only a single output for a Transpose node, "
2142               << "due to the way it is added by NodeProcessor.";
2143           NodeDef* output = *outputs.begin();
2144           string input = node_map_->GetNode(trans_first)->input(0);
2145           for (int i = 0; i < output->input_size(); i++) {
2146             if (output->input(i).compare(trans_second) == 0) {
2147               *output->mutable_input(i) = input;
2148               break;
2149             }
2150           }
2151           nodes_removable.insert(trans_first);
2152           nodes_removable.insert(trans_second);
2153         }
2154       }
2155     }
2156     graph_->mutable_node()->erase(
2157         std::remove_if(
2158             graph_->mutable_node()->begin(), graph_->mutable_node()->end(),
2159             [nodes_removable](const NodeDef& node) {
2160               return nodes_removable.find(node.name()) != nodes_removable.end();
2161             }),
2162         graph_->mutable_node()->end());
2163     return Status::OK();
2164   }
2165 
2166   const LayoutOptimizer::TuningConfig& config_;
2167 };
2168 
2169 int GetNumGPUs(const Cluster& cluster) {
2170   auto devices = cluster.GetDevices();
2171   int num_gpus = 0;
2172   for (const auto& device : devices) {
2173     if (device.second.type() == "GPU") {
2174       num_gpus++;
2175     }
2176   }
2177   return num_gpus;
2178 }
2179 }  // namespace
2180 
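// Tune() first annotates every node with its inferred "_output_shapes"
// attribute, which the node processors use to check tensor ranks; Collapse()
// strips the attribute again once the rewrite is done.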
2181 Status LayoutOptimizer::Tune(const GrapplerItem& item,
2182                              const GraphProperties& graph_properties,
2183                              const TuningConfig& config, GraphDef* output) {
2184   auto status = graph_properties.AnnotateOutputShapes(output);
2185   if (!status.ok()) {
2186     VLOG(1) << "Annotate shape return status: " << status.ToString();
2187     *output = item.graph;
2188     return status;
2189   }
2190   NodeMap node_map(output);
2191   DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
2192                                        config, nodes_to_preserve_, output,
2193                                        &node_map);
2194   status = layout_optimizer.Optimize();
2195   return status;
2196 }
2197 
2198 Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
2199                                  GraphDef* output) {
2200   if (cluster == nullptr) {
2201     return errors::InvalidArgument("cluster == nullptr");
2202   }
2203 
2204   if (GetNumGPUs(*cluster) < 1) {
2205     // LayoutOptimizer is currently only tuned for GPU.
2206     *output = item.graph;
2207     return Status::OK();
2208   }
2209 
2210   virtual_placer_.reset(new VirtualPlacer(cluster));
2211   nodes_to_preserve_ = item.NodesToPreserve();
2212   GraphProperties graph_properties(item);
2213   auto status = graph_properties.InferStatically(false);
2214   if (!status.ok()) {
2215     VLOG(1) << "Infer shape return status: " << status.ToString();
2216     *output = item.graph;
2217     return status;
2218   }
2219   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
2220 
2221   TuningConfig config;
2222   config.no_gemm = true;
2223   // TODO(yaozhang): Enable tuning with various TuningConfig choices with
2224   // the measurement-based estimator.
2225   status = Tune(item, graph_properties, config, output);
2226   if (!status.ok()) {
2227     *output = item.graph;
2228   }
2229   return status;
2230 }
2231 
2232 void LayoutOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
2233                                const GraphDef& optimize_output, double result) {
2234   // Nothing to do for LayoutOptimizer.
2235 }
2236 
2237 }  // end namespace grappler
2238 }  // end namespace tensorflow
2239