/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringSet.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace {

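// Nodes carrying this attribute are treated as TPU nodes by the
// functionalization filter in UpgradeLegacyGraph() below.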
constexpr char kTpuReplicateAttr[] = "_tpu_replicate";

// Returns the set of ops for which we want to generate a shared_name when it
// is empty.
const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() {
  static auto* const ops = new llvm::StringSet<>({"VariableV2", "Variable"});
  return *ops;
}

}  // namespace

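// Generates a value for the shared_name attribute of eligible resource ops
// whose shared_name is empty, both in the main graph and in its function
// library. Nodes in the main graph use the node name as the shared_name;
// nodes inside a function use "<node name>@<function name>".
//
// A minimal usage sketch (assuming `graph_def` is a GraphDef obtained
// elsewhere):
//
//   TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
//       graph_def, tensorflow::OpRegistry::Global()));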
Status GenerateResourceSharedNameIfEmpty(
    GraphDef& gdef, const OpRegistryInterface* default_registry) {
  auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
                                                  const OpDef& op_def) {
    if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) {
      // If this op is not in the allowlist, then it is likely a custom op.
      // For such ops, we currently rely on the "use_node_name_sharing"
      // attribute to decide whether it is valid to generate shared_names: if
      // the OpDef has a "use_node_name_sharing" field, then it is valid to
      // use node names as shared names.
      if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                       [](const auto& attr_def) {
                         return attr_def.name() == "use_node_name_sharing" &&
                                attr_def.type() == "bool";
                       }))
        return false;
    }

    if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                     [](const auto& attr_def) {
                       return attr_def.name() == "shared_name" &&
                              attr_def.type() == "string";
                     }))
      return false;

    auto iter = node_def.attr().find("shared_name");
    if (iter == node_def.attr().end()) return true;
    return iter->second.s().empty();
  };

  FunctionDefLibrary* library = gdef.mutable_library();
  auto flib_def = library ? std::make_unique<FunctionLibraryDefinition>(
                                default_registry, *library)
                          : std::make_unique<FunctionLibraryDefinition>(
                                default_registry, FunctionDefLibrary());

  if (library) {
    // Upgrade nodes in the functions.
    for (FunctionDef& fdef : *library->mutable_function()) {
      auto func_name = fdef.signature().name();
      for (auto& node_def : *fdef.mutable_node_def()) {
        const OpDef* op_def = nullptr;
        // With lazy loading, some functions might not be executed, so we skip
        // the node if the op is not registered.
        if (flib_def->LookUpOpDef(node_def.op(), &op_def).ok() &&
            is_resource_op_with_empty_shared_name(node_def, *op_def)) {
          // For such an op inside a function, use the concatenation of the
          // node name and the function name as the shared_name. "@" is used
          // as the separator because it is not allowed in function or node
          // names.
          (*node_def.mutable_attr())["shared_name"].set_s(
              absl::StrCat(node_def.name(), "@", func_name));
        }
      }
    }
  }

  // Upgrade nodes in the GraphDef.
  for (auto& node_def : *gdef.mutable_node()) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
    if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
      (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
    }
  }

  return tensorflow::Status::OK();
}

// A static device manager is used to avoid creating new devices every time
// RunGrappler() is called. In addition, the optimized graph may contain
// tensor protos that are only valid while the corresponding device is alive.
static const DeviceMgr* GetStaticDeviceMgr() {
  static const auto* const device_mgr = []() -> const DeviceMgr* {
    std::vector<std::unique_ptr<Device>> devices;
    // Only the CPU device is used. Instead of calling
    // DeviceFactory::AddDevices() with a dummy session config, which would
    // conflict with user-defined options and create unwanted devices, call
    // cpu_factory->CreateDevices() to create CPU-only devices.
    DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
    SessionOptions options;
    auto status = cpu_factory->CreateDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!status.ok()) {
      LOG(ERROR) << "Failed to create devices for Grappler: " << status;
      return nullptr;
    }

    return new StaticDeviceMgr(std::move(devices));
  }();

  return device_mgr;
}

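// Runs Grappler's meta-optimizer over `meta_graph_def` on a CPU-only virtual
// cluster and returns the optimized GraphDef.
//
// A minimal usage sketch (assuming `meta_graph_def` was built elsewhere):
//
//   auto optimized = RunGrappler(meta_graph_def);
//   if (!optimized.ok()) return optimized.status();
//   GraphDef graph_def = std::move(optimized.ValueOrDie());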
stream_executor::port::StatusOr<GraphDef> RunGrappler(
    const MetaGraphDef& meta_graph_def) {
  ConfigProto config_proto;
  // Avoid grappler logic that lowers to v1 control flow.
  config_proto.mutable_experimental()->set_use_tfrt(true);
  config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(false);
  // Do not skip grappler optimization even for small graphs.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_min_graph_nodes(-1);
  // Disable function optimization because inlining may cause restore graphs
  // to be removed when all graphs are optimized together.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_function_optimization(RewriterConfig::OFF);

  grappler::ItemConfig item_config;
  item_config.ignore_user_placement = false;
  std::unique_ptr<grappler::GrapplerItem> item =
      grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
                                             item_config);
  if (!item) {
    return tensorflow::errors::Internal(
        "Failed to create grappler item from MetaGraphDef.");
  }

  const auto* device_mgr = GetStaticDeviceMgr();
  if (!device_mgr) {
    return tensorflow::errors::Internal(
        "Failed to get devices in RunGrappler().");
  }

  DeviceSet dev_set;
  for (auto* d : device_mgr->ListDevices()) dev_set.AddDevice(d);
  grappler::VirtualCluster cluster(&dev_set);
  Device* cpu_device = device_mgr->HostCPU();

  GraphDef output_graph_def;
  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
      std::move(*item), config_proto, cpu_device, &cluster, &output_graph_def));

  return output_graph_def;
}

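// Upgrades `graph` by functionalizing Control Flow V1 ops so later passes
// only see V2-style (functional) control flow. When
// `restrict_functionalization_to_tpu_nodes` is true, only nodes carrying the
// `_tpu_replicate` attribute are functionalized.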
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
                          bool restrict_functionalization_to_tpu_nodes) {
  // If `restrict_functionalization_to_tpu_nodes` is true, use a filter that
  // returns true for nodes with the `_tpu_replicate` attribute; otherwise,
  // don't set a filter.
  NodeFilter node_filter =
      restrict_functionalization_to_tpu_nodes
          ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
          : NodeFilter{};
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      FunctionalizeControlFlow(graph, flib_def, node_filter,
                               /*include_functions=*/true),
      "Failed to functionalize Control Flow V1 ops. Consider using Control "
      "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
      "compat/v1/enable_control_flow_v2.");
  return Status::OK();
}

}  // namespace tensorflow