• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/rewrite_utils.h"
16 
17 // On mobile we do not provide this functionality because not all of its
18 // dependencies are available there.
19 #if !defined(IS_MOBILE_PLATFORM)
20 
21 #include <algorithm>
22 #include <functional>
23 #include <map>
24 #include <memory>
25 #include <string>
26 #include <unordered_map>
27 #include <utility>
28 #include <vector>
29 
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/substitute.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/graph_runner.h"
35 #include "tensorflow/core/common_runtime/metrics.h"
36 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
37 #include "tensorflow/core/data/dataset_utils.h"
38 #include "tensorflow/core/data/hash_utils.h"
39 #include "tensorflow/core/data/serialization_utils.h"
40 #include "tensorflow/core/framework/dataset.h"
41 #include "tensorflow/core/framework/function.h"
42 #include "tensorflow/core/framework/function.pb.h"
43 #include "tensorflow/core/framework/graph.pb.h"
44 #include "tensorflow/core/framework/node_def.pb.h"
45 #include "tensorflow/core/framework/op.h"
46 #include "tensorflow/core/framework/op_def_util.h"
47 #include "tensorflow/core/framework/op_kernel.h"
48 #include "tensorflow/core/framework/tensor.h"
49 #include "tensorflow/core/graph/graph.h"
50 #include "tensorflow/core/graph/graph_def_builder.h"
51 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
52 #include "tensorflow/core/grappler/graph_view.h"
53 #include "tensorflow/core/grappler/grappler_item.h"
54 #include "tensorflow/core/grappler/grappler_item_builder.h"
55 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
56 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
57 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
58 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
59 #include "tensorflow/core/lib/hash/hash.h"
60 #include "tensorflow/core/lib/strings/proto_serialization.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/status.h"
63 #include "tensorflow/core/platform/statusor.h"
64 #include "tensorflow/core/platform/tstring.h"
65 #include "tensorflow/core/protobuf/config.pb.h"
66 #include "tensorflow/core/protobuf/device_properties.pb.h"
67 #include "tensorflow/core/protobuf/meta_graph.pb.h"
68 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
69 
70 namespace tensorflow {
71 namespace data {
72 namespace {
73 
// Name under which the tf.data meta optimizer is added to the rewriter config,
// both as an optimizer entry and as the custom optimizer's name.
constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
// Key in the custom optimizer's parameter map holding the list of
// optimizations to apply.
constexpr char kOptimizers[] = "optimizers";
// Key in the custom optimizer's parameter map holding the list of optimization
// configuration strings.
constexpr char kOptimizerConfigs[] = "optimizer_configs";
77 
AddFakeSinks(FunctionDef * function_def)78 void AddFakeSinks(FunctionDef* function_def) {
79   int counter = 0;
80   for (const auto& output : function_def->signature().output_arg()) {
81     NodeDef* node = function_def->add_node_def();
82     tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
83         strings::StrCat("FakeSink", counter++), function_def, node);
84     node->set_op("Identity");
85     node->add_input(function_def->ret().at(output.name()));
86     (*node->mutable_attr())["T"].set_type(output.type());
87 
88     (*function_def->mutable_ret())[output.name()] =
89         strings::StrCat(node->name(), ":output:0");
90   }
91 }
92 
RemoveFakeSinks(FunctionDef * function_def)93 void RemoveFakeSinks(FunctionDef* function_def) {
94   // Map from identity node names to their input tensor strings
95   std::map<std::string, std::string> identity_map;
96   for (const auto& node : function_def->node_def()) {
97     if (node.op() == "Identity" && node.input_size() == 1) {
98       identity_map[node.name()] = node.input(0);
99     }
100   }
101   for (const auto& output_arg : function_def->signature().output_arg()) {
102     const std::string& tensor = function_def->ret().at(output_arg.name());
103     const std::string& output_node = tensor.substr(0, tensor.find(':'));
104     if (identity_map.find(output_node) != identity_map.end()) {
105       (*function_def->mutable_ret())[output_arg.name()] =
106           identity_map.at(output_node);
107     }
108   }
109 }
110 
ApplyRewrites(OpKernelContext * ctx,const std::function<RewriterConfig (void)> config_factory,GraphDef * graph_def,string * dataset_node)111 Status ApplyRewrites(OpKernelContext* ctx,
112                      const std::function<RewriterConfig(void)> config_factory,
113                      GraphDef* graph_def, string* dataset_node) {
114   std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
115       GetGrapplerItem(graph_def, dataset_node, /*add_fake_sinks=*/true);
116   std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
117   tensorflow::grappler::VirtualCluster cluster(device_map);
118 
119   // Run data optimizer using grappler's meta optimizer.
120   tensorflow::ConfigProto config;
121   *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
122   TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
123       std::move(*grappler_item), config, ctx->device(), &cluster, graph_def));
124 
125   // Remove fake sinks after optimizations are done.
126   //
127   // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
128   // to be optimizable, we will no longer need this.
129   for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
130     RemoveFakeSinks(&function_def);
131   }
132 
133   return Status::OK();
134 }
135 }  // anonymous namespace
136 
CreateRewriterConfig(const absl::flat_hash_set<tstring> & optimizations,const absl::flat_hash_set<tstring> & optimizations_configs)137 RewriterConfig CreateRewriterConfig(
138     const absl::flat_hash_set<tstring>& optimizations,
139     const absl::flat_hash_set<tstring>& optimizations_configs) {
140   RewriterConfig rewriter_config;
141   rewriter_config.add_optimizers(kOptimizerName);
142   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
143   rewriter_config.set_fail_on_optimizer_errors(true);
144   auto custom_optimizer = rewriter_config.add_custom_optimizers();
145   custom_optimizer->set_name(kOptimizerName);
146   auto* custom_optimizations_list =
147       (*custom_optimizer->mutable_parameter_map())[kOptimizers].mutable_list();
148   const auto& registered_optimizers =
149       grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
150   for (const auto& optimization : optimizations) {
151     if (std::find(registered_optimizers.begin(), registered_optimizers.end(),
152                   optimization) != registered_optimizers.end()) {
153       custom_optimizations_list->add_s(optimization.data(),
154                                        optimization.size());
155     } else {
156       VLOG(1) << "Optimization " << optimization << " is not registered.";
157     }
158   }
159   auto* config_list =
160       (*custom_optimizer->mutable_parameter_map())[kOptimizerConfigs]
161           .mutable_list();
162   for (const auto& config : optimizations_configs) {
163     config_list->add_s(config.data(), config.size());
164   }
165   return rewriter_config;
166 }
167 
RewriteDataset(OpKernelContext * ctx,const DatasetBase * input,std::function<RewriterConfig (void)> config_factory,bool record_fingerprint,DatasetBase ** rewritten_input)168 Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
169                       std::function<RewriterConfig(void)> config_factory,
170                       bool record_fingerprint, DatasetBase** rewritten_input) {
171   std::vector<std::pair<string, Tensor>> input_list;
172   GraphDef graph_def;
173   string output_node;
174   TF_RETURN_IF_ERROR(
175       AsGraphDefMinimal(ctx, input, &input_list, &graph_def, &output_node));
176 
177   VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
178   TF_RETURN_IF_ERROR(
179       ApplyRewrites(ctx, config_factory, &graph_def, &output_node));
180   VLOG(3) << "After graph rewrites: " << graph_def.DebugString();
181 
182   // Instantiate the optimized input pipeline by running the optimized graph
183   // using the optimized function library.
184   FunctionLibraryRuntime* flr = nullptr;
185   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
186   std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
187   TF_RETURN_IF_ERROR(
188       ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));
189 
190   // Some functions may have been modified without having their names changed
191   // (for example, nested dataset graphs from FlatMap or Interleave).
192   TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
193 
194   Graph graph(OpRegistry::Global());
195   TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
196   std::vector<Tensor> outputs;
197   GraphRunner graph_runner(flr->device());
198 
199   TF_RETURN_IF_ERROR(
200       graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
201   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input));
202   (*rewritten_input)->Ref();
203 
204   if (record_fingerprint) {
205     (*ctx->runner())([graph_def = std::move(graph_def),
206                       lib_def = lib_def.release(),
207                       input_list = std::move(input_list),
208                       output_node = std::move(output_node)]() {
209       std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
210       const NodeDef* node_def = nullptr;
211       for (const auto& node : graph_def.node()) {
212         if (node.name() == output_node) {
213           node_def = &node;
214           break;
215         }
216       }
217       if (node_def == nullptr) {
218         VLOG(3) << "Failed to find node: " << output_node;
219         return;
220       }
221       uint64 hash = 0;
222       Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
223       if (!s.ok()) {
224         VLOG(3) << "Failed to hash graph: " << s.ToString();
225         return;
226       }
227       for (const auto& pair : input_list) {
228         hash = Hash64CombineUnordered(hash, Hash64(pair.first));
229         uint64 tensor_hash = 0;
230         Status s = HashTensor(pair.second, &tensor_hash);
231         if (s.ok()) {
232           hash = Hash64CombineUnordered(hash, tensor_hash);
233         } else {
234           VLOG(3) << "Failed to hash tensor: " << s.ToString();
235         }
236       }
237       string graph_hash =
238           strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
239       metrics::RecordTFDataFingerprint(graph_hash);
240     });
241   }
242 
243   return Status::OK();
244 }
245 
GetGrapplerItem(GraphDef * graph_def,std::string * dataset_node,bool add_fake_sinks)246 std::unique_ptr<tensorflow::grappler::GrapplerItem> GetGrapplerItem(
247     GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks) {
248   // Add an identity node as the fetch node, otherwise we might get 'placeholder
249   // is both fed and fetched' errors in some cases when using input list with
250   // placeholder dataset nodes.
251   NodeDef* node = graph_def->mutable_node()->Add();
252   tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
253                                                             node);
254   node->set_op("Identity");
255   node->add_input(*dataset_node);
256   (*node->mutable_attr())["T"].set_type(DT_VARIANT);
257   *dataset_node = node->name();
258 
259   if (add_fake_sinks) {
260     // Add fake sink node to graph and functions to allow rewriting the actual
261     // sink nodes.
262     //
263     // TODO(b/118820916): When MetaOptimizer adds provisions for function
264     // retvals to be optimizable, we will no longer need this.
265     for (auto& function_def :
266          *graph_def->mutable_library()->mutable_function()) {
267       AddFakeSinks(&function_def);
268     }
269   }
270 
271   // Create metagraph.
272   MetaGraphDef meta_graph_def;
273   (*meta_graph_def.mutable_graph_def()) = *graph_def;
274 
275   // Grappler determines fetch ops from collection 'train_op'.
276   CollectionDef collection_def;
277   auto node_list = collection_def.mutable_node_list();
278   node_list->add_value(*dataset_node);
279   (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
280 
281   // Create Grappler item.
282   tensorflow::grappler::ItemConfig item_config;
283   item_config.apply_optimizations = true;
284   std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
285       tensorflow::grappler::GrapplerItemFromMetaGraphDef(
286           "graph", meta_graph_def, item_config);
287   // Grappler should not optimize function library of tf.data graphs. The
288   // tf.data meta optimizer takes care of optimizing tf.data functions.
289   grappler_item->optimization_options().optimize_function_library = false;
290   return grappler_item;
291 }
292 
GetDatasetNode(const GraphDef & graph_def)293 StatusOr<std::string> GetDatasetNode(const GraphDef& graph_def) {
294   // Symbolic `_Retval` node indicates which node corresponds to the dataset.
295   for (const auto& node : graph_def.node()) {
296     if (node.op() == "_Retval") {
297       return node.input(0);
298     }
299   }
300   return errors::NotFound(
301       absl::Substitute("Dataset node for graph is not found:\n$0",
302                        graph_def.ShortDebugString()));
303 }
304 
305 }  // namespace data
306 }  // namespace tensorflow
307 #endif  // !IS_MOBILE_PLATFORM
308