/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/implementation_selector.h"

#include <string>

#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/function_api_info.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

constexpr char kConstOp[] = "Const";
constexpr char kCaseOp[] = "Case";
constexpr char kStatelessCaseOp[] = "StatelessCase";
constexpr char kDeviceIndexOp[] = "DeviceIndex";

// TODO(b/157615690): clean up function implementation swap code.
// The overall idea of the function swap is as follows:
//           -----------                          -----------
//  inp_1 ->| P_C       | -> out_1     g_inp_1 ->| P_C       | -> g_out_1
//  inp_2 ->| forward   | -> out_2     g_inp_2 ->| backward  | -> g_out_2
//          | FUNC_1    | -> out_3     g_inp_3 ->| FUNC_1    |
//           -----------                          -----------
//             |  |  |                             ^  ^  ^
//             v  v  v                             |  |  |
//             s1 s2 s3                            s1 s2 s3
//             |                                   ^
//             |                                   |
//             |       --------------              |
//             |------>| Identity_1 |------------->|
//                      --------------
// P_C: a PartitionedCall or StatefulPartitionedCall op.
// FUNC_1 (forward): TF function generated for the forward path.
// FUNC_1 (backward): TF function generated for the backward path.
// inp_x: input tensors for the forward path.
// out_x: output tensors for the forward path.
// g_inp_x: gradient input tensors for the backward path.
// g_out_x: gradient output tensors for the backward path.
// s_x: intermediate results generated by the forward TF function, which are
//      consumed by the backward function for gradient calculation.
//
// In the example above, FUNC_1 takes 2 inputs and returns 3 outputs, and in
// the meantime generates 3 intermediate results for gradient calculation.
// The backward function takes 6 inputs: 3 for the gradient values of out_x,
// and 3 for the intermediate results s1/s2/s3. It returns 2 outputs, the
// gradient values with respect to inp_x.
//
// Given the graph, especially after device placement is done, we can check
// whether there is an alternative FUNC_2 that is better suited for the
// assigned device type. Note that FUNC_2 (both forward and backward) must
// have the same number of input and output tensors, with the same dtypes.
// However, it can generate different intermediate state tensors, both in
// number and in type, since those depend on the implementation details.
//
// Also note that there might be Identity ops added to the outputs of the
// forward function by IsolatePlacerInspectionRequiredOps for device
// placement. When the output DTYPE changes while switching from FUNC_1 to
// FUNC_2, the downstream Identity nodes also need to be updated with the new
// DTYPE.
//
// Based on this, the rewrite needs to happen for the following items:
//
// 1. The P_C forward/backward calls need to use FUNC_2 instead of FUNC_1.
// 2. The Tin of the backward P_C needs to be updated, since the s_x can be
//    different between FUNC_1 and FUNC_2.
// 3. The Tout of the forward P_C needs to be updated, since the s_x can be
//    different between FUNC_1 and FUNC_2.
// 4. The input edges of the backward P_C need to be updated, since the number
//    of intermediate results can be different between FUNC_1 and FUNC_2.
// 5. The DTYPE of the Identity nodes after s1/s2/s3 needs to be updated, if
//    they exist.
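//
// As a purely illustrative sketch (the function and dtype values below are
// hypothetical, not taken from any real graph), the forward P_C node might be
// rewritten roughly from:
//   op: "StatefulPartitionedCall"
//   attr { key: "f"    value { func { name: "lstm_generic" } } }
//   attr { key: "Tout" value { list { type: DT_FLOAT type: DT_FLOAT
//                                     type: DT_FLOAT type: DT_FLOAT } } }
// to:
//   op: "StatefulPartitionedCall"
//   attr { key: "f"    value { func { name: "lstm_cudnn" } } }
//   attr { key: "Tout" value { list { type: DT_FLOAT type: DT_FLOAT
//                                     type: DT_FLOAT } } }
// i.e. only the function attr and the intermediate-state portion of Tout (and
// Tin on the backward node) change; the user-visible inputs and outputs keep
// their dtypes.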

string FindForwardNode(utils::MutableNodeView* backward_node) {
  // For the TF function, an Identity op node might be added by
  // placer_inspection_required_ops_utils for device placement. Those ops might
  // be removed by model_pruner, or stay there if the Identity op crosses
  // devices. Given the partitioned call node for the backward function, we
  // want to find the partitioned call node for the forward function, so that
  // we can add/remove/update input tensors for the backward function, which is
  // step 4 described above.
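  //
  // For illustration only (these node names are hypothetical), the last input
  // of the backward call typically looks like either
  //   "unified_lstm/StatefulPartitionedCall:3"  (the forward node directly), or
  //   "unified_lstm/Identity_1:0"               (an Identity pass-through
  //                                              inserted for placement).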

  // Find the last input.
  const int last_input_index = backward_node->NumRegularFanins() - 1;
  const utils::MutableFanoutView& input =
      backward_node->GetRegularFanin(last_input_index);
  // The input node should either be the partitioned call, which is the
  // forward node we need, or an Identity op which just passes through the
  // output of the partitioned call.
  if (IsIdentity(*input.node_view()->node())) {
    // Find the only input to this op, which should be the original forward
    // node.
    return input.node_view()->node()->input(0);
  } else if (IsPartitionedCall(*input.node_view()->node()) ||
             IsStatefulPartitionedCall(*input.node_view()->node())) {
    // Found the forward node.
    return backward_node->node()->input(last_input_index);
  } else {
    // Unhandled situation.
    return "";
  }
}

void UpdateForwardIdentityNodeDtype(utils::MutableNodeView* forward_node,
                                    const DataTypeVector& dtypes) {
  const auto& fanouts_vector = forward_node->GetRegularFanouts();
  for (int pos = 0, pos_limit = fanouts_vector.size(); pos < pos_limit; ++pos) {
    const auto& fanouts_at_pos = fanouts_vector[pos];
    for (const auto& fanout : fanouts_at_pos) {
      if ("Identity" == fanout.node_view()->GetOp()) {
        (*fanout.node_view()->node()->mutable_attr())["T"].set_type(
            dtypes[pos]);
        VLOG(3) << "Updated DTYPE for Identity node: "
                << fanout.node_view()->node()->DebugString();
      }
    }
  }
}

Status UpdateNodeDef(utils::MutableNodeView* node_view, const string& funcName,
                     const FunctionApiInfo& apiInfo) {
  NodeDef* node_def = node_view->node();

  VLOG(3) << "Node def before swap is: " << node_def->DebugString();

  // For step 1 above.
  node_def->mutable_attr()->find("f")->second.mutable_func()->set_name(
      funcName);

  // For step 2 above.
  auto tin = node_def->mutable_attr()->find("Tin");
  tin->second.mutable_list()->clear_type();
  for (const auto& tin_dtype : apiInfo.input_arg_dtypes()) {
    tin->second.mutable_list()->add_type(tin_dtype);
  }

  // For step 3 above.
  auto tout = node_def->mutable_attr()->find("Tout");
  tout->second.mutable_list()->clear_type();
  for (const auto& tout_dtype : apiInfo.output_arg_dtypes()) {
    tout->second.mutable_list()->add_type(tout_dtype);
  }

  if (apiInfo.function_type() == FunctionApiInfo::BACKWARD) {
    // Strip node control dependencies. We'll add them back after updating
    // all the data inputs.
    std::vector<std::string> control_deps;
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      if (!IsControlInput(node_def->input(i))) break;
      control_deps.push_back(node_def->input(i));
      node_def->mutable_input()->RemoveLast();
    }

    // For step 4 above.
    const int prev_input_size = node_def->input_size();
    const int diff = prev_input_size - apiInfo.input_arg_dtypes().size();
    if (diff >= 0) {
      for (int i = 0; i < diff; ++i) node_def->mutable_input()->RemoveLast();
    } else {
      // Add new inputs for the internal states. The name of an internal state
      // should be in the format "{forward_node_name}:{index}", where the newly
      // added index starts from the last index of the existing states.
      // Eg:
      // {
      //   input: "gradients/unified_lstm/strided_slice_1_grad/StridedSliceGrad"
      //   input: "gradients/zeros_like_1"
      //   input: "gradients/zeros_like_2"
      //   input: "unified_lstm/StatefulPartitionedCall:3"
      //   input: "unified_lstm/StatefulPartitionedCall:4"
      //   # New input should be "unified_lstm/StatefulPartitionedCall:5"
      // }
      const string last_input = FindForwardNode(node_view);
      const std::vector<string> name_index = ::absl::StrSplit(last_input, ':');
      if (name_index.size() != 2) {
        return errors::InvalidArgument(
            "Invalid format of input node name: ", last_input,
            " Expected: {forward_node_name}:{index}");
      }
      const absl::string_view node_name = name_index[0];
      int last_index;
      if (!::absl::SimpleAtoi(name_index[1], &last_index)) {
        return errors::InvalidArgument(
            "The index of input node is expected to be number, got: ",
            name_index[1]);
      }
      for (int i = 1; i <= -diff; ++i)
        node_def->add_input(strings::StrCat(node_name, ":", i + last_index));
    }

    // Add control dependencies back.
    for (std::string& control : control_deps)
      node_def->add_input(std::move(control));

  } else if (apiInfo.function_type() == FunctionApiInfo::FORWARD) {
    // For the forward function, since the DTYPE of the intermediate states
    // might have been changed, we want to update the downstream Identity
    // nodes, if any. This is step 5 in the comment above.
    UpdateForwardIdentityNodeDtype(node_view, apiInfo.output_arg_dtypes());
  }

  VLOG(3) << "Node def after swap is: " << node_def->DebugString();
  return Status::OK();
}

Status ImplementationSelector::LoadFunctions(const GraphDef& graph) {
  lib_info_ = absl::make_unique<FunctionLibraryApiInfo>();
  TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
  return Status::OK();
}

Status ImplementationSelector::MaybeOptimizeFunctionCall(
    utils::MutableNodeView* node_view) const {
  // There are two ways of calling functions:
  // 1. By specifying an op name as a function name, or
  // 2. Via the @defun functional interface, where the real function call
  //    happens with the PartitionedCall op, and the function name appears as
  //    the attribute with name "f" and type func. In this case, there are
  //    more attributes that need to be taken care of, like Tin and Tout,
  //    which specify the DTYPEs of the inputs/outputs.
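  //
  // As an illustrative sketch (the function name "my_func" is hypothetical),
  // the two call styles look roughly like:
  //   node { name: "call_1" op: "my_func" input: "x" }
  //   node { name: "call_2" op: "PartitionedCall" input: "x"
  //          attr { key: "f"    value { func { name: "my_func" } } }
  //          attr { key: "Tin"  value { list { type: DT_FLOAT } } }
  //          attr { key: "Tout" value { list { type: DT_FLOAT } } } }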
  NodeDef* node_def = node_view->node();

  std::vector<string> function_attribute_names;
  for (const auto& attr : node_def->attr()) {
    if (attr.second.has_func() &&
        lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
      function_attribute_names.emplace_back(attr.first);
    }
  }

  if (function_attribute_names.empty() &&
      lib_info_->GetApiInfo(node_def->op()) == nullptr) {
    // A regular op, or a function which has no interface.
    return Status::OK();
  }

  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_def->device(), &parsed_name) ||
      !parsed_name.has_type) {
    return errors::Internal("Could not parse device name:", node_def->device());
  }
  VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
          << " = (" << parsed_name.type << ")";

  for (const auto& attr_name : function_attribute_names) {
    string function_name = node_def->attr().at(attr_name).func().name();
    // Skip the function if it's already optimized by the function optimizer.
    if (::absl::StrContains(function_name, "_specialized_for_")) continue;
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        function_name, &equiv_func_names));
    for (const auto& func_name : equiv_func_names) {
      const auto& func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        VLOG(2) << "Swapping: " << function_name << " TO: " << func_name;
        TF_RETURN_IF_ERROR(UpdateNodeDef(node_view, func_name, *func_api_info));
        break;
      }
    }
  }

  if (lib_info_->GetApiInfo(node_def->op()) != nullptr &&
      !::absl::StrContains(node_def->op(), "_specialized_for_")) {
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        node_def->op(), &equiv_func_names));
    for (const string& func_name : equiv_func_names) {
      const auto func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        node_def->set_op(func_name);
        break;
      }
    }
  }
  return Status::OK();
}

// Finds the index of the device from the device name list.
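// For example (illustrative values only): with device_names = ["CPU", "GPU"]
// and a Case node placed on "/job:localhost/replica:0/task:0/device:GPU:0",
// *index is set to 1; for an unlisted device type such as "TPU", *index is
// set to 2 (the device_names size), which selects the default branch.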
Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
                       const string& device, int* index) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
      !parsed_name.has_type) {
    return errors::Internal("Could not parse device name:", device);
  }
  const auto& device_list =
      device_index_node->GetAttr("device_names")->list().s();
  auto it = absl::c_find(device_list, parsed_name.type);
  if (it != device_list.end()) {
    *index = it - device_list.begin();
  } else {
    // Sets *index to device_list.size() because the default_fn is guaranteed
    // to be the final item in the case op branching list.
    *index = device_list.size();
  }
  return Status::OK();
}

// Rewrites the DeviceIndex op to a Const op holding the value of the index.
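// For illustration (node name and index value are hypothetical), a node
//   name: "branch_index" op: "DeviceIndex"
//   attr { key: "device_names" value { list { s: "CPU" s: "GPU" } } }
// placed on a GPU would be rewritten to roughly
//   name: "branch_index" op: "Const"
//   attr { key: "dtype" value { type: DT_INT32 } }
//   attr { key: "value" value { tensor { dtype: DT_INT32 int_val: 1 } } }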
void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
                          int index) {
  // Modifies the DeviceIndex node to be a Const op with the correct device
  // index.
  auto node = device_index_node->node();
  node->set_op(kConstOp);
  EraseRegularNodeAttributes(node);
  (*node->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  tensor->add_int_val(index);
  VLOG(2) << "Node after rewriting:" << node->DebugString();
}

Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
  Status status;
  VLOG(2) << "graph before rewriting device index:" << graph->DebugString();
  utils::MutableGraphView graph_view(graph, &status);
  TF_RETURN_IF_ERROR(status);
  const int num_nodes = graph_view.NumNodes();
  for (int k = 0; k < num_nodes; ++k) {
    auto* node_view = graph_view.GetNode(k);
    if (node_view->GetOp() != kDeviceIndexOp) {
      continue;
    }
    VLOG(2) << "Found a node to rewrite the device index";

    // Find the Case nodes that take the DeviceIndex node as input, and rewrite
    // the DeviceIndex node to hold the index of the device type of the Case
    // node.
    for (const auto& fanouts : node_view->GetRegularFanouts()) {
      for (const auto& fanout : fanouts) {
        if (fanout.node_view()->GetOp() != kCaseOp &&
            fanout.node_view()->GetOp() != kStatelessCaseOp)
          continue;
        int index;
        // If any error occurs during device parsing, we simply skip this node
        // and do not modify the DeviceIndex node.
        Status status =
            FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
        if (status.ok()) {
          RewriteDeviceIndexOp(node_view, index);
        }
      }
    }
  }
  return Status::OK();
}

Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
  if (!graph->has_library()) {
    VLOG(2) << "Skipping graph since it does not have function def";
    return Status::OK();
  }
  if (lib_info_->empty()) {
    VLOG(2) << "Skipping optimization since lib_info is empty";
    return Status::OK();
  }

  Status status;
  utils::MutableGraphView graph_view(graph, &status);
  TF_RETURN_IF_ERROR(status);

  const int num_nodes = graph_view.NumNodes();
  for (int k = 0; k < num_nodes; ++k) {
    TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
  }

  return Status::OK();
}

Status ImplementationSelector::Optimize(Cluster* cluster,
                                        const GrapplerItem& item,
                                        GraphDef* optimized_graph) {
  auto status = LoadFunctions(item.graph);
  // Swallow the error from function loading, since this optimizer might run
  // several times, and might try to run against functions generated by the
  // function_optimizer from previous runs, which will fail due to function
  // signature mismatch.
  if (!status.ok()) {
    VLOG(2) << "Skipping optimization due to error while loading function "
            << "libraries: " << status;
    return errors::Aborted("Skipped Optimization");
  }

  *optimized_graph = item.graph;
  status = SelectDeviceIndex(optimized_graph);
  if (!status.ok()) {
    *optimized_graph = item.graph;
    VLOG(2) << "Could not rewrite device index due to error:" << status;
  }
  return SelectImplementation(optimized_graph);
}

}  // end namespace grappler
}  // end namespace tensorflow