/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/rewrite_utils.h"

// On mobile we do not provide this functionality because not all of its
// dependencies are available there.
#if !defined(IS_MOBILE_PLATFORM)

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/hash_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
constexpr char kOptimizers[] = "optimizers";
constexpr char kOptimizerConfigs[] = "optimizer_configs";

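// Appends an Identity node ("FakeSink<N>") after each output of `function_def`
// and reroutes the corresponding retval through it. This gives the optimizers
// a rewritable node in front of every function retval; the sinks are undone by
// RemoveFakeSinks() once optimization has run.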
void AddFakeSinks(FunctionDef* function_def) {
  int counter = 0;
  for (const auto& output : function_def->signature().output_arg()) {
    NodeDef* node = function_def->add_node_def();
    tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
        strings::StrCat("FakeSink", counter++), function_def, node);
    node->set_op("Identity");
    node->add_input(function_def->ret().at(output.name()));
    (*node->mutable_attr())["T"].set_type(output.type());

    (*function_def->mutable_ret())[output.name()] =
        strings::StrCat(node->name(), ":output:0");
  }
}

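// Undoes the retval rewiring done by AddFakeSinks(): any output that resolves
// to a single-input Identity node is pointed back at that node's input tensor.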
void RemoveFakeSinks(FunctionDef* function_def) {
  // Map from identity node names to their input tensor strings
  std::map<std::string, std::string> identity_map;
  for (const auto& node : function_def->node_def()) {
    if (node.op() == "Identity" && node.input_size() == 1) {
      identity_map[node.name()] = node.input(0);
    }
  }
  for (const auto& output_arg : function_def->signature().output_arg()) {
    const std::string& tensor = function_def->ret().at(output_arg.name());
    const std::string& output_node = tensor.substr(0, tensor.find(':'));
    if (identity_map.find(output_node) != identity_map.end()) {
      (*function_def->mutable_ret())[output_arg.name()] =
          identity_map.at(output_node);
    }
  }
}

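// Converts `graph_def` into a GrapplerItem (adding fake sinks so function
// retvals can be rewritten), runs grappler's meta optimizer on a virtual
// cluster using the RewriterConfig produced by `config_factory`, writes the
// optimized graph back into `graph_def`, and removes the fake sinks again.
// `dataset_node` is updated to name the new fetch node.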
Status ApplyRewrites(OpKernelContext* ctx,
                     const std::function<RewriterConfig(void)> config_factory,
                     GraphDef* graph_def, string* dataset_node) {
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      GetGrapplerItem(graph_def, dataset_node, /*add_fake_sinks=*/true);
  std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);

  // Run data optimizer using grappler's meta optimizer.
  tensorflow::ConfigProto config;
  *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      std::move(*grappler_item), config, ctx->device(), &cluster, graph_def));

  // Remove fake sinks after optimizations are done.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    RemoveFakeSinks(&function_def);
  }

  return Status::OK();
}
}  // anonymous namespace

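// Builds a RewriterConfig that runs a single pass of the tf.data meta
// optimizer, forwarding the requested `optimizations` (skipping any that are
// not registered) and `optimizations_configs` as custom optimizer parameters.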
RewriterConfig CreateRewriterConfig(
    const absl::flat_hash_set<tstring>& optimizations,
    const absl::flat_hash_set<tstring>& optimizations_configs) {
  RewriterConfig rewriter_config;
  rewriter_config.add_optimizers(kOptimizerName);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
  rewriter_config.set_fail_on_optimizer_errors(true);
  auto custom_optimizer = rewriter_config.add_custom_optimizers();
  custom_optimizer->set_name(kOptimizerName);
  auto* custom_optimizations_list =
      (*custom_optimizer->mutable_parameter_map())[kOptimizers].mutable_list();
  const auto& registered_optimizers =
      grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
  for (const auto& optimization : optimizations) {
    if (std::find(registered_optimizers.begin(), registered_optimizers.end(),
                  optimization) != registered_optimizers.end()) {
      custom_optimizations_list->add_s(optimization.data(),
                                       optimization.size());
    } else {
      VLOG(1) << "Optimization " << optimization << " is not registered.";
    }
  }
  auto* config_list =
      (*custom_optimizer->mutable_parameter_map())[kOptimizerConfigs]
          .mutable_list();
  for (const auto& config : optimizations_configs) {
    config_list->add_s(config.data(), config.size());
  }
  return rewriter_config;
}

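// Rewrites `input` by serializing it to a GraphDef, applying the rewrites
// described by `config_factory`, and re-instantiating the optimized graph with
// a cloned function library. The resulting dataset is returned in
// `*rewritten_input` with a reference held for the caller. If
// `record_fingerprint` is true, a fingerprint of the optimized graph is
// computed asynchronously and exported as a metric.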
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      std::function<RewriterConfig(void)> config_factory,
                      bool record_fingerprint, DatasetBase** rewritten_input) {
  std::vector<std::pair<string, Tensor>> input_list;
  GraphDef graph_def;
  string output_node;
  TF_RETURN_IF_ERROR(
      AsGraphDefMinimal(ctx, input, &input_list, &graph_def, &output_node));

  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
  TF_RETURN_IF_ERROR(
      ApplyRewrites(ctx, config_factory, &graph_def, &output_node));
  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  FunctionLibraryRuntime* flr = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));

  // Some functions may have been modified without having their names changed
  // (for example, nested dataset graphs from FlatMap or Interleave).
  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(flr->device());

  TF_RETURN_IF_ERROR(
      graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input));
  (*rewritten_input)->Ref();

  if (record_fingerprint) {
    (*ctx->runner())([graph_def = std::move(graph_def),
                      lib_def = lib_def.release(),
                      input_list = std::move(input_list),
                      output_node = std::move(output_node)]() {
      std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
      const NodeDef* node_def = nullptr;
      for (const auto& node : graph_def.node()) {
        if (node.name() == output_node) {
          node_def = &node;
          break;
        }
      }
      if (node_def == nullptr) {
        VLOG(3) << "Failed to find node: " << output_node;
        return;
      }
      uint64 hash = 0;
      Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
      if (!s.ok()) {
        VLOG(3) << "Failed to hash graph: " << s.ToString();
        return;
      }
      for (const auto& pair : input_list) {
        hash = Hash64CombineUnordered(hash, Hash64(pair.first));
        uint64 tensor_hash = 0;
        Status s = HashTensor(pair.second, &tensor_hash);
        if (s.ok()) {
          hash = Hash64CombineUnordered(hash, tensor_hash);
        } else {
          VLOG(3) << "Failed to hash tensor: " << s.ToString();
        }
      }
      string graph_hash =
          strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
      metrics::RecordTFDataFingerprint(graph_hash);
    });
  }

  return Status::OK();
}

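// Wraps `graph_def` in a MetaGraphDef whose "train_op" collection points at a
// newly added "Sink" identity node (so grappler has an explicit fetch node) and
// converts it into a GrapplerItem. `*dataset_node` is updated to the sink's
// name. Function-library optimization is disabled because the tf.data meta
// optimizer handles tf.data functions itself.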
std::unique_ptr<tensorflow::grappler::GrapplerItem> GetGrapplerItem(
    GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks) {
  // Add an identity node as the fetch node, otherwise we might get 'placeholder
  // is both fed and fetched' errors in some cases when using input list with
  // placeholder dataset nodes.
  NodeDef* node = graph_def->mutable_node()->Add();
  tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
                                                            node);
  node->set_op("Identity");
  node->add_input(*dataset_node);
  (*node->mutable_attr())["T"].set_type(DT_VARIANT);
  *dataset_node = node->name();

  if (add_fake_sinks) {
    // Add fake sink node to graph and functions to allow rewriting the actual
    // sink nodes.
    //
    // TODO(b/118820916): When MetaOptimizer adds provisions for function
    // retvals to be optimizable, we will no longer need this.
    for (auto& function_def :
         *graph_def->mutable_library()->mutable_function()) {
      AddFakeSinks(&function_def);
    }
  }

  // Create metagraph.
  MetaGraphDef meta_graph_def;
  (*meta_graph_def.mutable_graph_def()) = *graph_def;

  // Grappler determines fetch ops from collection 'train_op'.
  CollectionDef collection_def;
  auto node_list = collection_def.mutable_node_list();
  node_list->add_value(*dataset_node);
  (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;

  // Create Grappler item.
  tensorflow::grappler::ItemConfig item_config;
  item_config.apply_optimizations = true;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef(
          "graph", meta_graph_def, item_config);
  // Grappler should not optimize function library of tf.data graphs. The
  // tf.data meta optimizer takes care of optimizing tf.data functions.
  grappler_item->optimization_options().optimize_function_library = false;
  return grappler_item;
}

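// Returns the name of the node that produces the dataset, identified as the
// input of the graph's `_Retval` node, or NotFound if no `_Retval` node exists.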
StatusOr<std::string> GetDatasetNode(const GraphDef& graph_def) {
  // Symbolic `_Retval` node indicates which node corresponds to the dataset.
  for (const auto& node : graph_def.node()) {
    if (node.op() == "_Retval") {
      return node.input(0);
    }
  }
  return errors::NotFound(
      absl::Substitute("Dataset node for graph is not found:\n$0",
                       graph_def.ShortDebugString()));
}

}  // namespace data
}  // namespace tensorflow
#endif  // !IS_MOBILE_PLATFORM