• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/device_set.h"
17 #include "tensorflow/core/common_runtime/eager/context.h"
18 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
19 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
20 #include "tensorflow/core/common_runtime/placer.h"
21 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
22 #include "tensorflow/core/framework/graph_to_functiondef.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
25 #include "tfrt/host_context/device.h"  // from @tf_runtime
26 #include "tfrt/support/error_util.h"  // from @tf_runtime
27 
28 namespace tensorflow {
29 
namespace {
// Device name passed to DeviceMgr::LookupDevice() to find the CPU device that
// is handed to Grappler when optimizing the function graph.
constexpr char kDefaultCpuDeviceName[] = "CPU:0";
}  // namespace
33 
TransformGraphFunction(const std::string & func_name,const FunctionDef & fdef,const std::string & device_name,const tensorflow::DeviceSet & device_set,EagerContext * eager_ctx,bool enable_grappler,std::unique_ptr<FunctionBody> * fbody,std::unique_ptr<Graph> graph,tfrt::ArrayRef<const tfrt::Device * > input_devices,FunctionLibraryDefinition * func_lib_def)34 Status TransformGraphFunction(const std::string& func_name,
35                               const FunctionDef& fdef,
36                               const std::string& device_name,
37                               const tensorflow::DeviceSet& device_set,
38                               EagerContext* eager_ctx, bool enable_grappler,
39                               std::unique_ptr<FunctionBody>* fbody,
40                               std::unique_ptr<Graph> graph,
41                               tfrt::ArrayRef<const tfrt::Device*> input_devices,
42                               FunctionLibraryDefinition* func_lib_def) {
43   const DeviceMgr* device_mgr = eager_ctx->local_device_mgr();
44   if (device_mgr == nullptr)
45     return errors::Internal("Cannot find device manager");
46   DumpGraph("Input function graph", graph.get());
47 
48   std::vector<string> ret_node_names;
49   std::vector<string> control_ret_node_names;
50   // Mapping from a function body node name to the control output name.
51   std::unordered_map<string, string> node_name_to_control_ret;
52   std::vector<Node*> arg_nodes, ret_nodes;
53   DataTypeVector ret_types;
54   auto attrs = AttrSlice(&fdef.attr());
55   TF_RETURN_IF_ERROR(GetGraphAndArgRets(
56       func_name, attrs, &fdef, func_lib_def, &graph, &arg_nodes, &ret_nodes,
57       &ret_node_names, &ret_types, &control_ret_node_names));
58   for (const auto& control_ret : fdef.control_ret()) {
59     node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
60   }
61   for (Node* node : arg_nodes) {
62     const AttrValue* attr_value;
63     TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
64     int64_t index = attr_value->i();
65     node->set_assigned_device_name(input_devices[index]->name().str());
66   }
67 
68   std::vector<string> input_device_names;
69   int input_size = input_devices.size();
70   input_device_names.reserve(input_size);
71   for (int i = 0; i < input_size; ++i) {
72     input_device_names.push_back(input_devices[i]->name().str());
73   }
74 
75   std::vector<string> output_device_names;
76   int output_size = fdef.signature().output_arg_size();
77   output_device_names.reserve(output_size);
78   for (int i = 0; i < output_size; ++i) {
79     output_device_names.push_back(device_name);
80   }
81 
82   // set default_device for placer.
83   Device* default_device = nullptr;
84   tensorflow::Status s = device_mgr->LookupDevice(device_name, &default_device);
85   if (!s.ok())
86     VLOG(1) << "TransformGraphFunction(): " << device_name << " is unknown."
87             << " default device for placer is not set.";
88 
89   TF_RETURN_IF_ERROR(ProcessFunctionLibraryRuntime::PinArgsAndRets(
90       input_device_names, output_device_names, device_set, arg_nodes, ret_nodes,
91       func_lib_def,
92       eager_ctx->AllowSoftPlacement() ? default_device : nullptr));
93   DumpGraph("After running PinArgsAndRets", graph.get());
94 
95   ConfigProto config;
96   bool control_rets_updated = false;
97   TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
98       device_set, config, &graph, func_lib_def, &control_ret_node_names,
99       &control_rets_updated));
100 
101   if (control_rets_updated) {
102     // Function graph pass may have resulted in different nodes/node names for
103     // control rets.
104     for (const auto& control_ret : control_ret_node_names) {
105       node_name_to_control_ret.emplace(control_ret, control_ret);
106     }
107   } else {
108     for (const auto& control_ret : fdef.control_ret()) {
109       node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
110     }
111   }
112   DumpGraph("After running function optimization pass (bridge)", graph.get());
113 
114   // Run function inlining so that placer can place ops in nested functions.
115   GraphOptimizationPassOptions optimization_options;
116   SessionOptions session_options;
117   // In TFRT we don't lower v2 control flow to v1.
118   session_options.config.mutable_experimental()->set_use_tfrt(true);
119   session_options.config.mutable_graph_options()
120       ->mutable_optimizer_options()
121       ->set_do_function_inlining(true);
122   optimization_options.session_options = &session_options;
123   optimization_options.graph = &graph;
124   optimization_options.flib_def = func_lib_def;
125   optimization_options.device_set = &device_set;
126   optimization_options.is_function_graph = true;
127   optimization_options.default_function_device = default_device;
128   optimization_options.function_def = &fdef;
129 
130   LowerFunctionalOpsPass pass;
131   TF_RETURN_IF_ERROR(pass.Run(optimization_options));
132 
133   // Run placer before importing GraphDef to MLIR.
134   Placer placer(graph.get(), func_name, func_lib_def, &device_set,
135                 default_device, eager_ctx->AllowSoftPlacement(),
136                 /*log_device_placement=*/false);
137   TF_RETURN_IF_ERROR(placer.Run());
138   DumpGraph("After running placer", graph.get());
139 
140   if (enable_grappler) {
141     Device* cpu_device;
142     TF_RETURN_IF_ERROR(
143         device_mgr->LookupDevice(kDefaultCpuDeviceName, &cpu_device));
144 
145     ConfigProto config_proto;
146     config_proto.mutable_experimental()->set_use_tfrt(true);
147     config_proto.mutable_graph_options()
148         ->mutable_optimizer_options()
149         ->set_do_function_inlining(true);
150     // Do not skip grappler optimization even for small graphs.
151     config_proto.mutable_graph_options()
152         ->mutable_rewrite_options()
153         ->set_min_graph_nodes(-1);
154 
155     grappler::GrapplerItem::OptimizationOptions grappler_options =
156         grappler::CreateOptOptionsForEager();
157     auto status = grappler::OptimizeGraph(
158         std::move(ret_node_names), std::move(control_ret_node_names),
159         func_lib_def, device_set, cpu_device, config_proto,
160         fdef.signature().name(), grappler_options, &graph);
161     if (!status.ok()) {
162       LOG(WARNING) << "Ignoring multi-device function optimization failure: "
163                    << status.ToString();
164     }
165     DumpGraph("After grappler optimization", graph.get());
166   }
167 
168   // We must preserve control returns in each of the function components,
169   // otherwise after function inlining we might prune side-effectful nodes.
170   const auto control_ret =
171       [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
172     const auto it = node_name_to_control_ret.find(n->name());
173     if (it != node_name_to_control_ret.end())
174       return absl::make_optional<string>(it->second);
175     return absl::nullopt;
176   };
177   FunctionDef new_func;
178   TF_RETURN_IF_ERROR(
179       GraphToFunctionDef(*graph, func_name, control_ret, &new_func));
180   // Refresh `fbody`.
181   TF_RETURN_IF_ERROR(
182       FunctionDefToBodyHelper(new_func, AttrSlice(), func_lib_def, fbody));
183   return Status::OK();
184 }
185 }  // namespace tensorflow
186