/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringSet.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace {

constexpr char kTpuReplicateAttr[] = "_tpu_replicate";

// Returns the set of ops for which we want to generate shared_names when the
// shared_name attribute is empty.
const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() {
  static auto* const ops = new llvm::StringSet<>({"VariableV2", "Variable"});
  return *ops;
}

}  // namespace

Status GenerateResourceSharedNameIfEmpty(
    GraphDef& gdef, const OpRegistryInterface* default_registry) {
  auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
                                                  const OpDef& op_def) {
    if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) {
      // If this op is not in the allowlist, it is likely a custom op. For
      // such ops we currently rely on the "use_node_name_sharing" attribute
      // to decide whether it is valid to generate shared_names: if the OpDef
      // has a "use_node_name_sharing" field, it is valid to use node names
      // as shared names.
      if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                       [](const auto& attr_def) {
                         return attr_def.name() == "use_node_name_sharing" &&
                                attr_def.type() == "bool";
                       }))
        return false;
    }

    if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
                     [](const auto& attr_def) {
                       return attr_def.name() == "shared_name" &&
                              attr_def.type() == "string";
                     }))
      return false;

    auto iter = node_def.attr().find("shared_name");
    if (iter == node_def.attr().end()) return true;
    return iter->second.s().empty();
  };
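
  // Illustrative example (not from the original source): a GraphDef node like
  //   name: "v" op: "VariableV2" attr { key: "shared_name" value { s: "" } }
  // satisfies the predicate above (VariableV2 is allowlisted and its
  // shared_name is empty), so the loops below generate a shared_name for it.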

  FunctionDefLibrary* library = gdef.mutable_library();
  auto flib_def = library ? std::make_unique<FunctionLibraryDefinition>(
                                default_registry, *library)
                          : std::make_unique<FunctionLibraryDefinition>(
                                default_registry, FunctionDefLibrary());

  if (library) {
    // Upgrade nodes in the functions.
    for (FunctionDef& fdef : *library->mutable_function()) {
      auto func_name = fdef.signature().name();
      for (auto& node_def : *fdef.mutable_node_def()) {
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
        if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
          // For such ops in a function, use the concatenation of the node
          // name and the function name as the shared_name, e.g. node "v" in
          // function "f" gets shared_name "v@f". "@" is used as the separator
          // because it is not allowed in function or node names.
          (*node_def.mutable_attr())["shared_name"].set_s(
              absl::StrCat(node_def.name(), "@", func_name));
        }
      }
    }
  }

  // Upgrade nodes in the GraphDef.
  for (auto& node_def : *gdef.mutable_node()) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
    if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
      (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
    }
  }

  return tensorflow::Status::OK();
}
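
// Usage sketch (illustrative, not part of the original file): callers with a
// GraphDef in hand would typically pass the global op registry:
//
//   GraphDef gdef = ...;  // obtained elsewhere
//   TF_RETURN_IF_ERROR(
//       GenerateResourceSharedNameIfEmpty(gdef, OpRegistry::Global()));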

// A static device manager is used to avoid creating new devices every time
// RunGrappler() is called. In addition, the optimized graph may contain tensor
// protos that are only valid while the corresponding device is alive.
static const DeviceMgr* GetStaticDeviceMgr() {
  static const auto* const device_mgr = []() -> const DeviceMgr* {
    std::vector<std::unique_ptr<Device>> devices;
    // Only the CPU device is used, so instead of calling
    // DeviceFactory::AddDevices() with a dummy session config (which would
    // conflict with user-defined options and create unwanted devices), call
    // cpu_factory->CreateDevices() to get CPU-only devices.
    DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
    SessionOptions options;
    auto status = cpu_factory->CreateDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!status.ok()) {
      LOG(ERROR) << "Failed to create devices for Grappler: " << status;
      return nullptr;
    }

    return new StaticDeviceMgr(std::move(devices));
  }();

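  // The manager above is created once and intentionally never destroyed, so
  // devices referenced by optimized graphs stay alive for the process
  // lifetime.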
  return device_mgr;
}

stream_executor::port::StatusOr<GraphDef> RunGrappler(
    const MetaGraphDef& meta_graph_def) {
  ConfigProto config_proto;
  // Avoid grappler logic that lowers to v1 control flow.
  config_proto.mutable_experimental()->set_use_tfrt(true);
  config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(false);
  // Do not skip grappler optimization even for small graphs.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_min_graph_nodes(-1);
  // Disable function inlining because it may cause restore graphs to be
  // removed as we optimize all graphs together.
  config_proto.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_function_optimization(RewriterConfig::OFF);

  grappler::ItemConfig item_config;
  item_config.ignore_user_placement = false;
  std::unique_ptr<grappler::GrapplerItem> item =
      grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
                                             item_config);
  if (!item) {
    return tensorflow::errors::Internal(
        "Failed to create grappler item from MetaGraphDef.");
  }

  const auto* device_mgr = GetStaticDeviceMgr();
  if (!device_mgr) {
    return tensorflow::errors::Internal(
        "Failed to get devices in RunGrappler().");
  }

  DeviceSet dev_set;
  for (auto* d : device_mgr->ListDevices()) dev_set.AddDevice(d);
  grappler::VirtualCluster cluster(&dev_set);
  Device* cpu_device = device_mgr->HostCPU();

  GraphDef output_graph_def;
  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
      std::move(*item), config_proto, cpu_device, &cluster, &output_graph_def));

  return output_graph_def;
}
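
// Usage sketch (illustrative, not part of the original file; accessor names
// assume the stream_executor StatusOr of this TF version):
//
//   auto result = RunGrappler(meta_graph_def);
//   if (!result.ok()) return result.status();
//   GraphDef optimized = std::move(result.ValueOrDie());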

Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
                          bool restrict_functionalization_to_tpu_nodes) {
  // If `restrict_functionalization_to_tpu_nodes` is true, use a filter that
  // matches only nodes with the `_tpu_replicate` attribute; otherwise leave
  // the filter unset.
  NodeFilter node_filter =
      restrict_functionalization_to_tpu_nodes
          ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
          : NodeFilter{};
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      FunctionalizeControlFlow(graph, flib_def, node_filter,
                               /*include_functions=*/true),
      "Failed to functionalize Control Flow V1 ops. Consider using Control "
      "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
      "compat/v1/enable_control_flow_v2.");
  return Status::OK();
}
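
// For reference: FunctionalizeControlFlow rewrites deprecated Control Flow V1
// primitives (e.g. Switch/Merge/NextIteration) into functional V2 ops
// (If/While); with the filter above, the rewrite is restricted to nodes
// carrying the `_tpu_replicate` attribute.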

}  // namespace tensorflow