1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/device_set.h"
17 #include "tensorflow/core/common_runtime/eager/context.h"
18 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
19 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
20 #include "tensorflow/core/common_runtime/placer.h"
21 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
22 #include "tensorflow/core/framework/graph_to_functiondef.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
25 #include "tfrt/host_context/device.h" // from @tf_runtime
26 #include "tfrt/support/error_util.h" // from @tf_runtime
27
28 namespace tensorflow {
29
namespace {
// Partial device name used to look up the host CPU via DeviceMgr::LookupDevice
// when running grappler, which requires a concrete CPU device.
constexpr char kDefaultCpuDeviceName[] = "CPU:0";
}  // namespace
33
TransformGraphFunction(const std::string & func_name,const FunctionDef & fdef,const std::string & device_name,const tensorflow::DeviceSet & device_set,EagerContext * eager_ctx,bool enable_grappler,std::unique_ptr<FunctionBody> * fbody,std::unique_ptr<Graph> graph,tfrt::ArrayRef<const tfrt::Device * > input_devices,FunctionLibraryDefinition * func_lib_def)34 Status TransformGraphFunction(const std::string& func_name,
35 const FunctionDef& fdef,
36 const std::string& device_name,
37 const tensorflow::DeviceSet& device_set,
38 EagerContext* eager_ctx, bool enable_grappler,
39 std::unique_ptr<FunctionBody>* fbody,
40 std::unique_ptr<Graph> graph,
41 tfrt::ArrayRef<const tfrt::Device*> input_devices,
42 FunctionLibraryDefinition* func_lib_def) {
43 const DeviceMgr* device_mgr = eager_ctx->local_device_mgr();
44 if (device_mgr == nullptr)
45 return errors::Internal("Cannot find device manager");
46 DumpGraph("Input function graph", graph.get());
47
48 std::vector<string> ret_node_names;
49 std::vector<string> control_ret_node_names;
50 // Mapping from a function body node name to the control output name.
51 std::unordered_map<string, string> node_name_to_control_ret;
52 std::vector<Node*> arg_nodes, ret_nodes;
53 DataTypeVector ret_types;
54 auto attrs = AttrSlice(&fdef.attr());
55 TF_RETURN_IF_ERROR(GetGraphAndArgRets(
56 func_name, attrs, &fdef, func_lib_def, &graph, &arg_nodes, &ret_nodes,
57 &ret_node_names, &ret_types, &control_ret_node_names));
58 for (const auto& control_ret : fdef.control_ret()) {
59 node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
60 }
61 for (Node* node : arg_nodes) {
62 const AttrValue* attr_value;
63 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
64 int64_t index = attr_value->i();
65 node->set_assigned_device_name(input_devices[index]->name().str());
66 }
67
68 std::vector<string> input_device_names;
69 int input_size = input_devices.size();
70 input_device_names.reserve(input_size);
71 for (int i = 0; i < input_size; ++i) {
72 input_device_names.push_back(input_devices[i]->name().str());
73 }
74
75 std::vector<string> output_device_names;
76 int output_size = fdef.signature().output_arg_size();
77 output_device_names.reserve(output_size);
78 for (int i = 0; i < output_size; ++i) {
79 output_device_names.push_back(device_name);
80 }
81
82 // set default_device for placer.
83 Device* default_device = nullptr;
84 tensorflow::Status s = device_mgr->LookupDevice(device_name, &default_device);
85 if (!s.ok())
86 VLOG(1) << "TransformGraphFunction(): " << device_name << " is unknown."
87 << " default device for placer is not set.";
88
89 TF_RETURN_IF_ERROR(ProcessFunctionLibraryRuntime::PinArgsAndRets(
90 input_device_names, output_device_names, device_set, arg_nodes, ret_nodes,
91 func_lib_def,
92 eager_ctx->AllowSoftPlacement() ? default_device : nullptr));
93 DumpGraph("After running PinArgsAndRets", graph.get());
94
95 ConfigProto config;
96 bool control_rets_updated = false;
97 TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
98 device_set, config, &graph, func_lib_def, &control_ret_node_names,
99 &control_rets_updated));
100
101 if (control_rets_updated) {
102 // Function graph pass may have resulted in different nodes/node names for
103 // control rets.
104 for (const auto& control_ret : control_ret_node_names) {
105 node_name_to_control_ret.emplace(control_ret, control_ret);
106 }
107 } else {
108 for (const auto& control_ret : fdef.control_ret()) {
109 node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
110 }
111 }
112 DumpGraph("After running function optimization pass (bridge)", graph.get());
113
114 // Run function inlining so that placer can place ops in nested functions.
115 GraphOptimizationPassOptions optimization_options;
116 SessionOptions session_options;
117 // In TFRT we don't lower v2 control flow to v1.
118 session_options.config.mutable_experimental()->set_use_tfrt(true);
119 session_options.config.mutable_graph_options()
120 ->mutable_optimizer_options()
121 ->set_do_function_inlining(true);
122 optimization_options.session_options = &session_options;
123 optimization_options.graph = &graph;
124 optimization_options.flib_def = func_lib_def;
125 optimization_options.device_set = &device_set;
126 optimization_options.is_function_graph = true;
127 optimization_options.default_function_device = default_device;
128 optimization_options.function_def = &fdef;
129
130 LowerFunctionalOpsPass pass;
131 TF_RETURN_IF_ERROR(pass.Run(optimization_options));
132
133 // Run placer before importing GraphDef to MLIR.
134 Placer placer(graph.get(), func_name, func_lib_def, &device_set,
135 default_device, eager_ctx->AllowSoftPlacement(),
136 /*log_device_placement=*/false);
137 TF_RETURN_IF_ERROR(placer.Run());
138 DumpGraph("After running placer", graph.get());
139
140 if (enable_grappler) {
141 Device* cpu_device;
142 TF_RETURN_IF_ERROR(
143 device_mgr->LookupDevice(kDefaultCpuDeviceName, &cpu_device));
144
145 ConfigProto config_proto;
146 config_proto.mutable_experimental()->set_use_tfrt(true);
147 config_proto.mutable_graph_options()
148 ->mutable_optimizer_options()
149 ->set_do_function_inlining(true);
150 // Do not skip grappler optimization even for small graphs.
151 config_proto.mutable_graph_options()
152 ->mutable_rewrite_options()
153 ->set_min_graph_nodes(-1);
154
155 grappler::GrapplerItem::OptimizationOptions grappler_options =
156 grappler::CreateOptOptionsForEager();
157 auto status = grappler::OptimizeGraph(
158 std::move(ret_node_names), std::move(control_ret_node_names),
159 func_lib_def, device_set, cpu_device, config_proto,
160 fdef.signature().name(), grappler_options, &graph);
161 if (!status.ok()) {
162 LOG(WARNING) << "Ignoring multi-device function optimization failure: "
163 << status.ToString();
164 }
165 DumpGraph("After grappler optimization", graph.get());
166 }
167
168 // We must preserve control returns in each of the function components,
169 // otherwise after function inlining we might prune side-effectful nodes.
170 const auto control_ret =
171 [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
172 const auto it = node_name_to_control_ret.find(n->name());
173 if (it != node_name_to_control_ret.end())
174 return absl::make_optional<string>(it->second);
175 return absl::nullopt;
176 };
177 FunctionDef new_func;
178 TF_RETURN_IF_ERROR(
179 GraphToFunctionDef(*graph, func_name, control_ret, &new_func));
180 // Refresh `fbody`.
181 TF_RETURN_IF_ERROR(
182 FunctionDefToBodyHelper(new_func, AttrSlice(), func_lib_def, fbody));
183 return Status::OK();
184 }
185 } // namespace tensorflow
186