/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"

#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringSet.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace {

constexpr char kTpuReplicateAttr[] = "_tpu_replicate";

// Returns the set of ops for which we want to generate shared_names if they
// are empty.
const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() {
  static auto* const ops = new llvm::StringSet<>({"VariableV2", "Variable"});
  return *ops;
}

}  // namespace

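// Generates a `shared_name` for resource ops (and custom ops that opt into
// node-name sharing) whose `shared_name` attribute is empty, both inside the
// functions of the graph's function library and in the top-level GraphDef.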
Status GenerateResourceSharedNameIfEmpty(
    GraphDef& gdef, const OpRegistryInterface* default_registry) {
  auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
                                                  const OpDef& op_def) {
    if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) {
      // If this op is not in the allowlist, then it is likely a custom op.
      // For such ops, we currently rely on the "use_node_name_sharing"
      // attribute to decide whether it is valid to generate shared_names. If
      // the OpDef has a boolean "use_node_name_sharing" attribute, then it is
      // valid to use node names as shared names.
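      // (For example, hash table ops such as "HashTableV2" declare a boolean
      // "use_node_name_sharing" attribute.)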
      if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                       [](const auto& attr_def) {
                         return attr_def.name() == "use_node_name_sharing" &&
                                attr_def.type() == "bool";
                       }))
        return false;
    }

    if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                     [](const auto& attr_def) {
                       return attr_def.name() == "shared_name" &&
                              attr_def.type() == "string";
                     }))
      return false;

    auto iter = node_def.attr().find("shared_name");
    if (iter == node_def.attr().end()) return true;
    return iter->second.s().empty();
  };

  FunctionDefLibrary* library = gdef.mutable_library();
  auto flib_def = library ? std::make_unique<FunctionLibraryDefinition>(
                                default_registry, *library)
                          : std::make_unique<FunctionLibraryDefinition>(
                                default_registry, FunctionDefLibrary());

  if (library) {
    // Upgrade nodes in the functions.
    for (FunctionDef& fdef : *library->mutable_function()) {
      auto func_name = fdef.signature().name();
      for (auto& node_def : *fdef.mutable_node_def()) {
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
        if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
          // For such ops inside a function, use the node name concatenated
          // with the function name as the shared_name. "@" is used as the
          // separator because it is not allowed in function or node names.
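          // E.g. a node named "v" inside a function named "init_fn" gets the
          // shared_name "v@init_fn".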
          (*node_def.mutable_attr())["shared_name"].set_s(
              absl::StrCat(node_def.name(), "@", func_name));
        }
      }
    }
  }

  // Upgrade nodes in the GraphDef.
  for (auto& node_def : *gdef.mutable_node()) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
    if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
      (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
    }
  }

  return tensorflow::Status::OK();
}
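
// A minimal usage sketch (hypothetical caller): given a GraphDef `graph_def`
// loaded from a v1 SavedModel or frozen graph, the global op registry is
// typically passed as the default registry:
//
//   TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
//       graph_def, tensorflow::OpRegistry::Global()));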

// The static device manager is used to avoid creating new devices every time
// RunGrappler() is called. In addition, the optimized graph may contain tensor
// protos that are only valid while the corresponding device is alive.
static const DeviceMgr* GetStaticDeviceMgr() {
  static const auto* const device_mgr = []() -> const DeviceMgr* {
    std::vector<std::unique_ptr<Device>> devices;
    // Only the CPU device is used, so instead of calling
    // DeviceFactory::AddDevices() with a dummy session config, which would
    // conflict with user-defined options and create unwanted devices, call
    // cpu_factory->CreateDevices() to get CPU-only devices.
    DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
    SessionOptions options;
    auto status = cpu_factory->CreateDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!status.ok()) {
      LOG(ERROR) << "Failed to create devices for Grappler: " << status;
      return nullptr;
    }

    return new StaticDeviceMgr(std::move(devices));
  }();

  return device_mgr;
}

stream_executor::port::StatusOr<GraphDef> RunGrappler(
    const MetaGraphDef& meta_graph_def) {
  ConfigProto config_proto;
  // Avoid grappler logic that lowers to v1 control flow.
  config_proto.mutable_experimental()->set_use_tfrt(true);
  config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(false);
  // Do not skip grappler optimization even for small graphs.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_min_graph_nodes(-1);
  // Disable function inlining because it may cause restore graphs to be
  // removed as we optimize all graphs together.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_function_optimization(RewriterConfig::OFF);

  grappler::ItemConfig item_config;
  item_config.ignore_user_placement = false;
  std::unique_ptr<grappler::GrapplerItem> item =
      grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
                                             item_config);
  if (!item) {
    return tensorflow::errors::Internal(
        "Failed to create grappler item from MetaGraphDef.");
  }

  const auto* device_mgr = GetStaticDeviceMgr();
  if (!device_mgr) {
    return tensorflow::errors::Internal(
        "Failed to get devices in RunGrappler().");
  }

  DeviceSet dev_set;
  for (auto* d : device_mgr->ListDevices()) dev_set.AddDevice(d);
  grappler::VirtualCluster cluster(&dev_set);
  Device* cpu_device = device_mgr->HostCPU();

  GraphDef output_graph_def;
  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
      std::move(*item), config_proto, cpu_device, &cluster, &output_graph_def));

  return output_graph_def;
}
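
// A minimal usage sketch (hypothetical caller): wrap a GraphDef in a
// MetaGraphDef before handing it to the meta optimizer:
//
//   MetaGraphDef meta_graph_def;
//   *meta_graph_def.mutable_graph_def() = graph_def;
//   auto optimized = RunGrappler(meta_graph_def);
//   if (!optimized.ok()) return optimized.status();
//   graph_def = std::move(optimized.ValueOrDie());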

Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
                          bool restrict_functionalization_to_tpu_nodes) {
  // If `restrict_functionalization_to_tpu_nodes` is true, let the filter
  // function return true only for nodes with the `_tpu_replicate` attribute;
  // otherwise don't set a filter.
  NodeFilter node_filter =
      restrict_functionalization_to_tpu_nodes
          ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
          : NodeFilter{};
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      FunctionalizeControlFlow(graph, flib_def, node_filter,
                               /*include_functions=*/true),
      "Failed to functionalize Control Flow V1 ops. Consider using Control "
      "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
      "compat/v1/enable_control_flow_v2.");
  return Status::OK();
}

}  // namespace tensorflow