/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringSet.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace {

constexpr char kTpuReplicateAttr[] = "_tpu_replicate";

// Returns the set of ops for which we want to generate shared_names if their
// shared_name attribute is empty.
const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() {
  static auto* const ops = new llvm::StringSet<>({"VariableV2", "Variable"});
  return *ops;
}

}  // namespace

Status GenerateResourceSharedNameIfEmpty(
    GraphDef& gdef, const OpRegistryInterface* default_registry) {
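  // Returns true if `node_def` should receive a generated shared_name: the op
  // must either be in the allowlist or opt in via `use_node_name_sharing`, it
  // must declare a string `shared_name` attr, and that attr must currently be
  // unset or empty on the node.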
  auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
                                                  const OpDef& op_def) {
    if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) {
      // If this op is not in the allowlist, then it is likely a custom op.
      // For these ops, we currently rely on the "use_node_name_sharing" attr
      // to decide whether it is valid to generate shared_names. If the OpDef
      // has a "use_node_name_sharing" field, then it is valid to use node
      // names as shared names.
      if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                       [](const auto& attr_def) {
                         return attr_def.name() == "use_node_name_sharing" &&
                                attr_def.type() == "bool";
                       }))
        return false;
    }

    if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                     [](const auto& attr_def) {
                       return attr_def.name() == "shared_name" &&
                              attr_def.type() == "string";
                     }))
      return false;

    auto iter = node_def.attr().find("shared_name");
    if (iter == node_def.attr().end()) return true;
    return iter->second.s().empty();
  };

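  // Build a FunctionLibraryDefinition so that OpDefs can be looked up for
  // both registered ops and ops defined in the graph's function library.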
  FunctionDefLibrary* library = gdef.mutable_library();
  auto flib_def = library ? std::make_unique<FunctionLibraryDefinition>(
                                default_registry, *library)
                          : std::make_unique<FunctionLibraryDefinition>(
                                default_registry, FunctionDefLibrary());

  if (library) {
    // Upgrade nodes in the functions.
    for (FunctionDef& fdef : *library->mutable_function()) {
      auto func_name = fdef.signature().name();
      for (auto& node_def : *fdef.mutable_node_def()) {
        const OpDef* op_def = nullptr;
        // With lazy loading, some functions might never be executed, so we
        // skip the node if its op is not registered.
        if (flib_def->LookUpOpDef(node_def.op(), &op_def).ok() &&
            is_resource_op_with_empty_shared_name(node_def, *op_def)) {
          // For such ops in a function, use the concatenation of the node
          // name and the function name as the shared_name. "@" is used as the
          // separator because it is not allowed in function or node names.
          (*node_def.mutable_attr())["shared_name"].set_s(
              absl::StrCat(node_def.name(), "@", func_name));
        }
      }
    }
  }

  // Upgrade nodes in the GraphDef.
  for (auto& node_def : *gdef.mutable_node()) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
    if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
      (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
    }
  }

  return tensorflow::Status::OK();
}

// The static device manager is used to avoid creating new devices every time
// RunGrappler() is called. In addition, the optimized graph may contain tensor
// protos that are only valid while the corresponding device is alive.
static const DeviceMgr* GetStaticDeviceMgr() {
  static const auto* const device_mgr = []() -> const DeviceMgr* {
    std::vector<std::unique_ptr<Device>> devices;
    // Only the CPU device is used. Instead of calling
    // DeviceFactory::AddDevices() with a dummy session config, which would
    // conflict with user-defined options and create unwanted devices, call
    // cpu_factory->CreateDevices() to get CPU-only devices.
    DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
    SessionOptions options;
    auto status = cpu_factory->CreateDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!status.ok()) {
      LOG(ERROR) << "Failed to create devices for Grappler: " << status;
      return nullptr;
    }

    return new StaticDeviceMgr(std::move(devices));
  }();

  return device_mgr;
}

stream_executor::port::StatusOr<GraphDef> RunGrappler(
    const MetaGraphDef& meta_graph_def) {
  ConfigProto config_proto;
  // Avoid grappler logic that lowers to v1 control flow.
  config_proto.mutable_experimental()->set_use_tfrt(true);
  config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(false);
  // Do not skip grappler optimization even for small graphs.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_min_graph_nodes(-1);
  // Disable function inlining because it may cause restore graphs to be
  // removed as we optimize all graphs together.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_function_optimization(RewriterConfig::OFF);

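  // Keep any user-specified device placements from the MetaGraphDef rather
  // than letting Grappler discard them.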
  grappler::ItemConfig item_config;
  item_config.ignore_user_placement = false;
  std::unique_ptr<grappler::GrapplerItem> item =
      grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
                                             item_config);
  if (!item) {
    return tensorflow::errors::Internal(
        "Failed to create grappler item from MetaGraphDef.");
  }

  const auto* device_mgr = GetStaticDeviceMgr();
  if (!device_mgr) {
    return tensorflow::errors::Internal(
        "Failed to get devices in RunGrappler().");
  }

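  // Wrap the static CPU devices in a virtual cluster so the meta optimizer
  // can reason about available devices without a real session.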
  DeviceSet dev_set;
  for (auto* d : device_mgr->ListDevices()) dev_set.AddDevice(d);
  grappler::VirtualCluster cluster(&dev_set);
  Device* cpu_device = device_mgr->HostCPU();

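  // Run the Grappler meta-optimizer pipeline with the configuration above and
  // return the optimized GraphDef.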
  GraphDef output_graph_def;
  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
      std::move(*item), config_proto, cpu_device, &cluster, &output_graph_def));

  return output_graph_def;
}

Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
                          bool restrict_functionalization_to_tpu_nodes) {
  // If `restrict_functionalization_to_tpu_nodes` is true, let the filter
  // return true only for nodes with a `_tpu_replicate` attribute; otherwise,
  // leave the filter unset so all nodes are considered.
  NodeFilter node_filter =
      restrict_functionalization_to_tpu_nodes
          ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
          : NodeFilter{};
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      FunctionalizeControlFlow(graph, flib_def, node_filter,
                               /*include_functions=*/true),
      "Failed to functionalize Control Flow V1 ops. Consider using Control "
      "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
      "compat/v1/enable_control_flow_v2.");
  return Status::OK();
}

}  // namespace tensorflow