• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/constant_folding.h"
17 #include "tensorflow/core/common_runtime/graph_constructor.h"
18 #include "tensorflow/core/graph/node_builder.h"
19 #include "tensorflow/core/graph/subgraph.h"
20 #include "tensorflow/core/platform/init_main.h"
21 #include "tensorflow/core/public/session.h"
22 #include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
23 #include "tensorflow/tools/graph_transforms/transform_utils.h"
24 
25 namespace tensorflow {
26 namespace graph_transforms {
27 
FuseResizePadAndConv(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)28 Status FuseResizePadAndConv(const GraphDef& input_graph_def,
29                             const TransformFuncContext& context,
30                             GraphDef* output_graph_def) {
31   GraphDef replaced_graph_def;
32   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
33       input_graph_def,  // clang-format off
34       {"Conv2D",
35           {
36               {"MirrorPad",
37                   {
38                       {"ResizeBilinear"},
39                       {"*"}
40                   }
41               },
42               {"*"}
43           }
44       },  // clang-format on
45       [](const NodeMatch& match, const std::set<string>& input_nodes,
46          const std::set<string>& output_nodes,
47          std::vector<NodeDef>* new_nodes) {
48         // Find all the nodes we expect in the subgraph.
49         const NodeDef& conv_node = match.node;
50         const NodeDef& mirror_pad_node = match.inputs[0].node;
51         const NodeDef& weights_node = match.inputs[1].node;
52         const NodeDef& resize_node = match.inputs[0].inputs[0].node;
53         const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node;
54 
55         // We'll be reusing the old weights and pad dimensions.
56         new_nodes->push_back(weights_node);
57         new_nodes->push_back(pad_dims_node);
58 
59         // Set up the new fused version of the convolution op.
60         NodeDef fused_conv;
61         fused_conv.set_op("FusedResizeAndPadConv2D");
62         fused_conv.set_name(match.node.name());
63         AddNodeInput(resize_node.input(0), &fused_conv);
64         AddNodeInput(resize_node.input(1), &fused_conv);
65         AddNodeInput(mirror_pad_node.input(1), &fused_conv);
66         AddNodeInput(conv_node.input(1), &fused_conv);
67         CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
68                      &fused_conv);
69         CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv);
70         CopyNodeAttr(conv_node, "T", "T", &fused_conv);
71         CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
72         CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
73         new_nodes->push_back(fused_conv);
74 
75         return OkStatus();
76       },
77       {}, &replaced_graph_def));
78   *output_graph_def = replaced_graph_def;
79   return OkStatus();
80 }
81 
FuseResizeAndConv(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)82 Status FuseResizeAndConv(const GraphDef& input_graph_def,
83                          const TransformFuncContext& context,
84                          GraphDef* output_graph_def) {
85   GraphDef replaced_graph_def;
86   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
87       input_graph_def,  // clang-format off
88       {"Conv2D",
89           {
90               {"ResizeBilinear"},
91               {"*"}
92           }
93       },  // clang-format on
94       [](const NodeMatch& match, const std::set<string>& input_nodes,
95          const std::set<string>& output_nodes,
96          std::vector<NodeDef>* new_nodes) {
97         // Find all the nodes we expect in the subgraph.
98         const NodeDef& conv_node = match.node;
99         const NodeDef& resize_node = match.inputs[0].node;
100         const NodeDef& weights_node = match.inputs[1].node;
101 
102         // We'll be reusing the old weights.
103         new_nodes->push_back(weights_node);
104 
105         // Create a 'no-op' mirror padding node that has no effect.
106         NodeDef pad_dims_node;
107         pad_dims_node.set_op("Const");
108         pad_dims_node.set_name(conv_node.name() + "_dummy_paddings");
109         SetNodeAttr("dtype", DT_INT32, &pad_dims_node);
110         SetNodeTensorAttr<int32>("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0},
111                                  &pad_dims_node);
112         new_nodes->push_back(pad_dims_node);
113 
114         // Set up the new fused version of the convolution op.
115         NodeDef fused_conv;
116         fused_conv.set_op("FusedResizeAndPadConv2D");
117         fused_conv.set_name(match.node.name());
118         AddNodeInput(resize_node.input(0), &fused_conv);
119         AddNodeInput(resize_node.input(1), &fused_conv);
120         AddNodeInput(pad_dims_node.name(), &fused_conv);
121         AddNodeInput(conv_node.input(1), &fused_conv);
122         CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
123                      &fused_conv);
124         SetNodeAttr("mode", "REFLECT", &fused_conv);
125         CopyNodeAttr(conv_node, "T", "T", &fused_conv);
126         CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
127         CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
128         new_nodes->push_back(fused_conv);
129 
130         return OkStatus();
131       },
132       {}, &replaced_graph_def));
133   *output_graph_def = replaced_graph_def;
134   return OkStatus();
135 }
136 
FusePadAndConv(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)137 Status FusePadAndConv(const GraphDef& input_graph_def,
138                       const TransformFuncContext& context,
139                       GraphDef* output_graph_def) {
140   GraphDef replaced_graph_def;
141   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
142       input_graph_def,  // clang-format off
143       {"Conv2D",
144           {
145               {"MirrorPad",
146                   {
147                       {"*"},
148                       {"*"},
149                   }
150               },
151               {"*"}
152           }
153       },  // clang-format on
154       [](const NodeMatch& match, const std::set<string>& input_nodes,
155          const std::set<string>& output_nodes,
156          std::vector<NodeDef>* new_nodes) {
157         // Find all the nodes we expect in the subgraph.
158         const NodeDef& conv_node = match.node;
159         CHECK_EQ("Conv2D", conv_node.op());
160         const NodeDef& mirror_pad_node = match.inputs[0].node;
161         CHECK_EQ("MirrorPad", mirror_pad_node.op());
162         const NodeDef& weights_node = match.inputs[1].node;
163         const NodeDef& input_node = match.inputs[0].inputs[0].node;
164         const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node;
165 
166         // We'll be reusing the old weights and pad dimensions.
167         new_nodes->push_back(weights_node);
168         new_nodes->push_back(input_node);
169         new_nodes->push_back(pad_dims_node);
170 
171         // Set up the new fused version of the convolution op.
172         NodeDef fused_conv;
173         fused_conv.set_op("FusedPadConv2D");
174         fused_conv.set_name(match.node.name());
175         AddNodeInput(mirror_pad_node.input(0), &fused_conv);
176         AddNodeInput(mirror_pad_node.input(1), &fused_conv);
177         AddNodeInput(conv_node.input(1), &fused_conv);
178         CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv);
179         CopyNodeAttr(conv_node, "T", "T", &fused_conv);
180         CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
181         CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
182         new_nodes->push_back(fused_conv);
183 
184         return OkStatus();
185       },
186       {}, &replaced_graph_def));
187   *output_graph_def = replaced_graph_def;
188   return OkStatus();
189 }
190 
191 REGISTER_GRAPH_TRANSFORM("fuse_resize_pad_and_conv", FuseResizePadAndConv);
192 
193 REGISTER_GRAPH_TRANSFORM("fuse_resize_and_conv", FuseResizeAndConv);
194 
195 REGISTER_GRAPH_TRANSFORM("fuse_pad_and_conv", FusePadAndConv);
196 
197 }  // namespace graph_transforms
198 }  // namespace tensorflow
199