• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/substitute.h"
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/function.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
32 #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
33 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/topological_sort.h"
36 #include "tensorflow/core/kernels/function_ops.h"
37 #include "tensorflow/core/lib/gtl/map_util.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44 
// Builds the node that replaces the original map: it keeps the op type and
// inputs of `map_node`, but its "f" attr refers to `fused_function`, and one
// extra boolean output (the predicate result) is appended to the output
// types/shapes. `filter_node` is only forwarded to the fused-metadata helper.
NodeDef MakeFusedNode(const NodeDef& map_node, const NodeDef& filter_node,
                      const FunctionDef& fused_function,
                      MutableGraphView* graph) {
  NodeDef result;
  graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &result);
  result.set_op(map_node.op());

  // The fused node reads from exactly the same inputs as the original map.
  for (const auto& input_name : map_node.input()) {
    result.add_input(input_name);
  }

  // Start from the map's function attr and retarget it at the fused function.
  auto& function_attr = (*result.mutable_attr())["f"];
  function_attr = map_node.attr().at("f");
  function_attr.mutable_func()->set_name(fused_function.signature().name());

  // Attrs every map node is expected to carry.
  graph_utils::CopyAttribute("Targuments", map_node, &result);
  graph_utils::CopyShapesAndTypesAttrs(map_node, &result);

  // Attrs that are copied only when the original map defines them.
  for (const char* optional_key :
       {"use_inter_op_parallelism", "sloppy", "preserve_cardinality"}) {
    if (gtl::FindOrNull(map_node.attr(), optional_key) != nullptr) {
      graph_utils::CopyAttribute(optional_key, map_node, &result);
    }
  }
  graph_utils::MaybeSetFusedMetadata(map_node, filter_node, &result);

  // Account for the additional predicate output: a bool with a
  // default-constructed shape entry.
  auto& attrs = *result.mutable_attr();
  attrs["output_types"].mutable_list()->add_type(DT_BOOL);
  attrs["output_shapes"].mutable_list()->add_shape();

  return result;
}
86 
// Builds a `FilterDataset` node that consumes `fused_map` and uses the last
// component produced by the fused map function as its predicate.
//
// A helper function is appended to `library`. Its inputs mirror the outputs of
// `fused_map_func` and it returns the last one as "predicate_result". The
// helper gets a library-unique name (with "GetLast" as the prefix) so that
// fusing several map+filter pairs in one graph does not register multiple,
// potentially different, functions under the same name.
//
// Args:
//   fused_map: the fused map node feeding the filter.
//   fused_map_func: the function executed by `fused_map`; its last output is
//     expected to be the boolean predicate result.
//   graph: used to pick a unique node name.
//   library: receives the generated predicate function.
NodeDef MakeFilterNode(const NodeDef& fused_map,
                       const FunctionDef& fused_map_func,
                       MutableGraphView* graph, FunctionDefLibrary* library) {
  NodeDef filter_node;
  graph_utils::SetUniqueGraphNodeName("FilterByLast", graph->graph(),
                                      &filter_node);
  filter_node.set_op("FilterDataset");
  filter_node.add_input(fused_map.name());

  // The filter passes elements through unchanged, so it has the same output
  // shapes and types as the fused map (including the trailing predicate).
  graph_utils::CopyShapesAndTypesAttrs(fused_map, &filter_node);

  // The generated predicate captures nothing.
  AddNodeAttr("Targuments", std::vector<DataType>({}), &filter_node);

  const OpDef& fused_sig = fused_map_func.signature();
  FunctionDef* func = library->add_function();
  // Avoid clashing with a helper generated for another fused pair.
  graph_utils::SetUniqueGraphFunctionName("GetLast", library, func);

  OpDef* sig = func->mutable_signature();
  for (const auto& arg : fused_sig.output_arg()) {
    *(sig->add_input_arg()) = arg;
  }
  OpDef::ArgDef* arg = sig->add_output_arg();
  arg->set_name("predicate_result");
  arg->set_description("predicate result computed in the fused map");
  arg->set_type(DT_BOOL);
  sig->set_description("returns the last argument");
  // Route the last input argument straight to the function's only output.
  (*func->mutable_ret())["predicate_result"] = strings::StrCat(
      fused_sig.output_arg(fused_sig.output_arg_size() - 1).name(), ":0");

  (*filter_node.mutable_attr())["predicate"] =
      FunctionDefHelper::FunctionRef(func->signature().name()).proto;
  return filter_node;
}
119 
// Builds a `MapDataset` node that consumes `updated_filter` and drops the last
// component (the predicate result) of every element, restoring the output
// signature of `original_map`.
//
// A helper function is appended to `library`. Its inputs mirror the outputs of
// `fused_map_func` and it returns all of them except the last. The helper gets
// a library-unique name (with "DropLast" as the prefix) so that fusing several
// map+filter pairs in one graph does not register multiple, potentially
// different, functions under the same name.
//
// Args:
//   updated_filter: the filter node feeding this map.
//   original_map: the map node being replaced; supplies output shapes/types
//     and the optional attrs copied below.
//   fused_map_func: the fused map+predicate function.
//   graph: used to pick a unique node name.
//   library: receives the generated function.
NodeDef MakeMapNode(const NodeDef& updated_filter, const NodeDef& original_map,
                    const FunctionDef& fused_map_func, MutableGraphView* graph,
                    FunctionDefLibrary* library) {
  NodeDef map_node;
  graph_utils::SetUniqueGraphNodeName("DropLast", graph->graph(), &map_node);
  // We use MapDataset even if the original map was ParallelMap. Non-parallel
  // map is more performant for simple short-circuit functions like (x, y) -> x.
  map_node.set_op("MapDataset");
  map_node.add_input(updated_filter.name());

  graph_utils::CopyShapesAndTypesAttrs(original_map, &map_node);

  // The generated function captures nothing.
  AddNodeAttr("Targuments", std::vector<DataType>({}), &map_node);

  // Copied only when the original map defines them.
  for (const char* key : {"use_inter_op_parallelism", "preserve_cardinality"}) {
    if (gtl::FindOrNull(original_map.attr(), key)) {
      graph_utils::CopyAttribute(key, original_map, &map_node);
    }
  }

  const OpDef& fused_sig = fused_map_func.signature();
  FunctionDef* func = library->add_function();
  // Avoid clashing with a helper generated for another fused pair.
  graph_utils::SetUniqueGraphFunctionName("DropLast", library, func);

  OpDef* sig = func->mutable_signature();
  for (const auto& o : fused_sig.output_arg()) {
    *(sig->add_input_arg()) = o;
  }
  // Forward every input except the trailing predicate result.
  for (int i = 0; i < fused_sig.output_arg_size() - 1; ++i) {
    const OpDef::ArgDef& arg_i = fused_sig.output_arg(i);
    *(sig->add_output_arg()) = arg_i;
    (*func->mutable_ret())[arg_i.name()] = strings::StrCat(arg_i.name(), ":0");
  }
  sig->set_description("drops the last argument");

  (*map_node.mutable_attr())["f"] =
      FunctionDefHelper::FunctionRef(func->signature().name()).proto;
  return map_node;
}
158 
159 }  // namespace
160 
OptimizeAndCollectStats(Cluster * cluster,const GrapplerItem & item,GraphDef * output,OptimizationStats * stats)161 Status MapAndFilterFusion::OptimizeAndCollectStats(Cluster* cluster,
162                                                    const GrapplerItem& item,
163                                                    GraphDef* output,
164                                                    OptimizationStats* stats) {
165   GraphDef sorted_old_graph = item.graph;
166   TF_RETURN_IF_ERROR(TopologicalSort(&sorted_old_graph));
167   // TODO(prazek): We might have some problems with performance if we copy
168   // the whole graph too much.
169   *output = sorted_old_graph;
170 
171   MutableGraphView graph(output);
172   absl::flat_hash_set<string> nodes_to_delete;
173   FunctionLibraryDefinition function_library(OpRegistry::Global(),
174                                              item.graph.library());
175   auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
176     // TODO(b/148614315): Support captured inputs.
177     if ((node.op() == "MapDataset" && node.input_size() == 1) ||
178         (node.op() == "ParallelMapDataset" && node.input_size() == 2)) {
179       return &node;
180     }
181     return nullptr;
182   };
183 
184   auto get_filter_node = [](const NodeDef& node) -> const NodeDef* {
185     // TODO(b/148614315): Support captured inputs.
186     if (node.op() == "FilterDataset" && node.input_size() == 1) return &node;
187     return nullptr;
188   };
189 
190   auto make_fused_function = [&function_library, &output](
191                                  const NodeDef* map_node,
192                                  const NodeDef* filter_node) -> FunctionDef* {
193     const auto& parent_fun = map_node->attr().at("f");
194     const FunctionDef* map_func =
195         function_library.Find(parent_fun.func().name());
196     const auto& fun = filter_node->attr().at("predicate");
197     const FunctionDef* filter_func = function_library.Find(fun.func().name());
198     if (!fusion_utils::CanCompose(map_func->signature(),
199                                   filter_func->signature())) {
200       VLOG(1) << "Can't fuse map and filter because the output signature of "
201                  "the map function does not match the input signature of the "
202                  "filter function\n";
203       return nullptr;
204     }
205     return fusion_utils::FuseFunctions(
206         *map_func, *filter_func, "fused_map_and_filter_function",
207         fusion_utils::CombineSignature, fusion_utils::ComposeInput,
208         fusion_utils::CombineOutput, fusion_utils::MergeNodes,
209         output->mutable_library());
210   };
211 
212   for (const NodeDef& node : sorted_old_graph.node()) {
213     const NodeDef* filter_node = get_filter_node(node);
214     if (!filter_node) continue;
215 
216     const NodeDef* map_node =
217         get_map_node(*graph_utils::GetInputNode(*filter_node, graph));
218     if (!map_node) continue;
219 
220     const auto* fused_function = make_fused_function(map_node, filter_node);
221     if (fused_function == nullptr) continue;
222 
223     const auto* fused_maps = graph.AddNode(
224         MakeFusedNode(*map_node, *filter_node, *fused_function, &graph));
225 
226     const auto* new_filter_node = graph.AddNode(MakeFilterNode(
227         *fused_maps, *fused_function, &graph, output->mutable_library()));
228 
229     const auto* new_map_node =
230         graph.AddNode(MakeMapNode(*new_filter_node, *map_node, *fused_function,
231                                   &graph, output->mutable_library()));
232 
233     TF_RETURN_IF_ERROR(
234         graph.UpdateFanouts(filter_node->name(), new_map_node->name()));
235     TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
236 
237     nodes_to_delete.insert(map_node->name());
238     nodes_to_delete.insert(filter_node->name());
239     stats->num_changes++;
240   }
241 
242   TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
243   return OkStatus();
244 }
245 
246 REGISTER_GRAPH_OPTIMIZER_AS(MapAndFilterFusion, "map_and_filter_fusion");
247 
248 }  // namespace grappler
249 }  // namespace tensorflow
250