/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/kernels/function_ops.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace grappler {
namespace {
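// Creates a map node that runs the fused map-and-filter function. Its outputs
// are the outputs of the original map function plus a trailing DT_BOOL
// component holding the filter predicate's result.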
NodeDef MakeFusedNode(const NodeDef& map_node, const NodeDef& filter_node,
                      const FunctionDef& fused_function,
                      MutableGraphView* graph) {
  NodeDef fused_node;
  graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(),
                                      &fused_node);
  fused_node.set_op(map_node.op());

  // Copy over inputs.
  for (int i = 0; i < map_node.input_size(); ++i) {
    fused_node.add_input(map_node.input(i));
  }

  auto attr = map_node.attr().at("f");
  attr.mutable_func()->set_name(fused_function.signature().name());
  (*fused_node.mutable_attr())["f"] = std::move(attr);

  // Required attrs.
  graph_utils::CopyAttribute("Targuments", map_node, &fused_node);
  graph_utils::CopyShapesAndTypesAttrs(map_node, &fused_node);

  // Optional attrs.
  for (auto key :
       {"use_inter_op_parallelism", "sloppy", "preserve_cardinality"}) {
    if (gtl::FindOrNull(map_node.attr(), key)) {
      graph_utils::CopyAttribute(key, map_node, &fused_node);
    }
  }
  graph_utils::MaybeSetFusedMetadata(map_node, filter_node, &fused_node);

  // Add the predicate output attributes.
  (*fused_node.mutable_attr())["output_types"]
      .mutable_list()
      ->mutable_type()
      ->Add(DT_BOOL);
  (*fused_node.mutable_attr())["output_shapes"]
      .mutable_list()
      ->mutable_shape()
      ->Add();

  return fused_node;
}
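// Creates a FilterDataset node that keeps elements whose trailing (boolean)
// component, produced by the fused map, is true. The predicate is a small
// "GetLast" function added to `library`.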
NodeDef MakeFilterNode(const NodeDef& fused_map,
                       const FunctionDef& fused_map_func,
                       MutableGraphView* graph, FunctionDefLibrary* library) {
  NodeDef filter_node;
  graph_utils::SetUniqueGraphNodeName("FilterByLast", graph->graph(),
                                      &filter_node);
  filter_node.set_op("FilterDataset");
  filter_node.add_input(fused_map.name());

  graph_utils::CopyShapesAndTypesAttrs(fused_map, &filter_node);

  AddNodeAttr("Targuments", std::vector<DataType>({}), &filter_node);
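  // Build the "GetLast" predicate function: it takes all outputs of the fused
  // map and returns only the last one, i.e. the DT_BOOL predicate result.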
  OpDef fused_sig = fused_map_func.signature();
  FunctionDef* func = library->add_function();
  OpDef* sig = func->mutable_signature();
  sig->set_name("GetLast");
  for (const auto& arg : fused_sig.output_arg()) {
    *(sig->add_input_arg()) = arg;
  }
  OpDef::ArgDef* arg = sig->add_output_arg();
  arg->set_name("predicate_result");
  arg->set_description("predicate result computed in the fused map");
  arg->set_type(DT_BOOL);
  sig->set_description("returns the last argument");
  (*func->mutable_ret())["predicate_result"] = strings::StrCat(
      fused_sig.output_arg(fused_sig.output_arg_size() - 1).name(), ":0");

  (*filter_node.mutable_attr())["predicate"] =
      FunctionDefHelper::FunctionRef(func->signature().name()).proto;
  return filter_node;
}
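// Creates a MapDataset node that drops the trailing predicate component,
// restoring the output signature of the original map.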
NodeDef MakeMapNode(const NodeDef& updated_filter, const NodeDef& original_map,
                    const FunctionDef& fused_map_func, MutableGraphView* graph,
                    FunctionDefLibrary* library) {
  NodeDef map_node;
  graph_utils::SetUniqueGraphNodeName("DropLast", graph->graph(), &map_node);
  // We use MapDataset even if the original map was ParallelMap. Non-parallel
  // map is more performant for simple short-circuit functions like
  // (x, y) -> x.
  map_node.set_op("MapDataset");
  map_node.add_input(updated_filter.name());

  graph_utils::CopyShapesAndTypesAttrs(original_map, &map_node);

  AddNodeAttr("Targuments", std::vector<DataType>({}), &map_node);

  for (auto key : {"use_inter_op_parallelism", "preserve_cardinality"}) {
    if (gtl::FindOrNull(original_map.attr(), key)) {
      graph_utils::CopyAttribute(key, original_map, &map_node);
    }
  }
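  // Build the "DropLast" function: it takes all outputs of the fused map and
  // forwards every component except the trailing predicate result.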
  OpDef fused_sig = fused_map_func.signature();
  FunctionDef* func = library->add_function();
  OpDef* sig = func->mutable_signature();
  sig->set_name("DropLast");
  for (const auto& o : fused_sig.output_arg()) {
    *(sig->add_input_arg()) = o;
  }
  for (int i = 0; i < fused_sig.output_arg_size() - 1; ++i) {
    auto arg_i = fused_sig.output_arg(i);
    *(sig->add_output_arg()) = arg_i;
    (*func->mutable_ret())[arg_i.name()] = strings::StrCat(arg_i.name(), ":0");
  }
  sig->set_description("drops the last argument");

  (*map_node.mutable_attr())["f"] =
      FunctionDefHelper::FunctionRef(func->signature().name()).proto;
  return map_node;
}

}  // namespace
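// Rewrites `map(f) -> filter(p)` into `map(f + p) -> filter(take last) ->
// map(drop last)`, so that the predicate is evaluated as part of the
// (possibly parallel) map function rather than in a separate filter pass.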
Status MapAndFilterFusion::OptimizeAndCollectStats(Cluster* cluster,
                                                   const GrapplerItem& item,
                                                   GraphDef* output,
                                                   OptimizationStats* stats) {
  GraphDef sorted_old_graph = item.graph;
  TF_RETURN_IF_ERROR(TopologicalSort(&sorted_old_graph));
  // TODO(prazek): We might have some problems with performance if we copy
  // the whole graph too much.
  *output = sorted_old_graph;

  MutableGraphView graph(output);
  absl::flat_hash_set<string> nodes_to_delete;
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             item.graph.library());
  auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
    // TODO(b/148614315): Support captured inputs.
    if ((node.op() == "MapDataset" && node.input_size() == 1) ||
        (node.op() == "ParallelMapDataset" && node.input_size() == 2)) {
      return &node;
    }
    return nullptr;
  };

  auto get_filter_node = [](const NodeDef& node) -> const NodeDef* {
    // TODO(b/148614315): Support captured inputs.
    if (node.op() == "FilterDataset" && node.input_size() == 1) return &node;
    return nullptr;
  };

  auto make_fused_function = [&function_library, &output](
                                 const NodeDef* map_node,
                                 const NodeDef* filter_node) -> FunctionDef* {
    const auto& parent_fun = map_node->attr().at("f");
    const FunctionDef* map_func =
        function_library.Find(parent_fun.func().name());
    const auto& fun = filter_node->attr().at("predicate");
    const FunctionDef* filter_func = function_library.Find(fun.func().name());
    if (!fusion_utils::CanCompose(map_func->signature(),
                                  filter_func->signature())) {
      VLOG(1) << "Can't fuse map and filter because the output signature of "
                 "the map function does not match the input signature of the "
                 "filter function\n";
      return nullptr;
    }
    return fusion_utils::FuseFunctions(
        *map_func, *filter_func, "fused_map_and_filter_function",
        fusion_utils::CombineSignature, fusion_utils::ComposeInput,
        fusion_utils::CombineOutput, fusion_utils::MergeNodes,
        output->mutable_library());
  };
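  // For every FilterDataset whose input is a (Parallel)MapDataset, insert the
  // fused map, the FilterByLast node, and the DropLast map, reroute the
  // filter's fanouts to the new map, and mark the old nodes for deletion.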
  for (const NodeDef& node : sorted_old_graph.node()) {
    const NodeDef* filter_node = get_filter_node(node);
    if (!filter_node) continue;

    const NodeDef* map_node =
        get_map_node(*graph_utils::GetInputNode(*filter_node, graph));
    if (!map_node) continue;

    const auto* fused_function = make_fused_function(map_node, filter_node);
    if (fused_function == nullptr) continue;

    const auto* fused_maps = graph.AddNode(
        MakeFusedNode(*map_node, *filter_node, *fused_function, &graph));

    const auto* new_filter_node = graph.AddNode(MakeFilterNode(
        *fused_maps, *fused_function, &graph, output->mutable_library()));

    const auto* new_map_node =
        graph.AddNode(MakeMapNode(*new_filter_node, *map_node, *fused_function,
                                  &graph, output->mutable_library()));

    TF_RETURN_IF_ERROR(
        graph.UpdateFanouts(filter_node->name(), new_map_node->name()));
    TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));

    nodes_to_delete.insert(map_node->name());
    nodes_to_delete.insert(filter_node->name());
    stats->num_changes++;
  }

  TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
  return OkStatus();
}

REGISTER_GRAPH_OPTIMIZER_AS(MapAndFilterFusion, "map_and_filter_fusion");

}  // namespace grappler
}  // namespace tensorflow