/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/implementation_selector.h"

#include <string>

#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/function_api_info.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

constexpr char kConstOp[] = "Const";
constexpr char kCaseOp[] = "Case";
constexpr char kStatelessCaseOp[] = "StatelessCase";
constexpr char kDeviceIndexOp[] = "DeviceIndex";

// TODO(b/157615690): clean up function implementation swap code.
// The overall idea of the function swap is illustrated below:
//          -----------                            -----------
//  inp_1 ->|  P_C    | -> out_1         g_inp_1 ->|  P_C    | -> g_out_1
//  inp_2 ->| forward | -> out_2         g_inp_2 ->| backward| -> g_out_2
//          | FUNC_1  | -> out_3         g_inp_3 ->| FUNC_1  |
//          -----------                            -----------
//           |  |  |                                 ^  ^  ^
//           v  v  v                                 |  |  |
//           s1 s2 s3                                s1 s2 s3
//           |                                       ^
//           |                                       |
//           |             --------------            |
//           |-----------> | Identity_1 | ---------->|
//                         --------------
// P_C: a PartitionedCall or StatefulPartitionedCall op.
// FUNC_1 (forward): the TF function generated for the forward path.
// FUNC_1 (backward): the TF function generated for the backward path.
// inp_x: input tensors for the forward path.
// out_x: output tensors for the forward path.
// g_inp_x: gradient input tensors for the backward path.
// g_out_x: gradient output tensors for the backward path.
// s_x: intermediate results generated by the forward TF function, which are
//      consumed by the backward function for gradient calculation.
//
// In the example above, FUNC_1 takes 2 inputs and returns 3 outputs, while
// also generating 3 intermediate results for gradient calculation. The
// backward function takes 6 inputs: 3 for the gradient values of out_x and 3
// for the intermediate results s1/s2/s3. It returns 2 outputs, the gradients
// w.r.t. inp_x.
//
// Given the graph, especially after device placement is done, we can check
// whether there is an alternative FUNC_2 that is better suited to the assigned
// device type. Note that FUNC_2 (both forward and backward) must have the same
// number of input and output tensors, with the same dtypes. However, it may
// generate different intermediate state tensors, both in number and in type,
// since those depend on the implementation details.
//
// Also note that Identity ops might be added to the outputs of the forward
// function by IsolatePlacerInspectionRequiredOps for device placement. When an
// output DTYPE changes while switching from FUNC_1 to FUNC_2, the downstream
// Identity nodes also need to be updated with the new DTYPE.
//
// Based on this, the rewrite needs to cover the following items:
//
// 1. The forward/backward P_C need to use FUNC_2 instead of FUNC_1.
// 2. The Tin of the backward P_C needs to be updated since the s_x can be
//    different between FUNC_1 and FUNC_2.
// 3. The Tout of the forward P_C needs to be updated since the s_x can be
//    different between FUNC_1 and FUNC_2.
// 4. The input edges of the backward P_C need to be updated since the number
//    of intermediate results can be different between FUNC_1 and FUNC_2.
// 5. The DTYPEs of the Identity nodes after s1/s2/s3 need to be updated if
//    they exist.
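//
// As an illustration of items 2-4 (with hypothetical dtypes, for exposition
// only): suppose FUNC_1's forward implementation produces two intermediate
// states (s1: DT_FLOAT, s2: DT_INT32) while the chosen FUNC_2 produces three,
// all DT_FLOAT. After the swap, the forward P_C's "Tout" list drops the
// trailing DT_FLOAT, DT_INT32 and gains DT_FLOAT, DT_FLOAT, DT_FLOAT; the
// backward P_C's "Tin" list changes the same way, and one extra input edge of
// the form "{forward_node_name}:{next_index}" is appended so the additional
// state is wired through.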

string FindForwardNode(utils::MutableNodeView* backward_node) {
  // For the TF function, Identity op nodes might be added by
  // placer_inspection_required_ops_utils for device placement. Those ops might
  // be removed by model_pruner, or stay if the Identity op crosses devices.
  // Given the partitioned_call node for the backward function, we want to find
  // the partitioned_call node for the forward function, so that we can
  // add/remove/update input tensors for the backward function, which is step 4
  // described above.

  // Find the last input.
  const int last_input_index = backward_node->NumRegularFanins() - 1;
  const utils::MutableFanoutView& input =
      backward_node->GetRegularFanin(last_input_index);
  // The input node should either be the partitioned call, which is the forward
  // node we need, or an Identity op that just passes through the output of the
  // partitioned call.
  if (IsIdentity(*input.node_view()->node())) {
    // Find the only input to this op, which should be the original forward
    // node.
    return input.node_view()->node()->input(0);
  } else if (IsPartitionedCall(*input.node_view()->node()) ||
             IsStatefulPartitionedCall(*input.node_view()->node())) {
    // Found the forward node.
    return backward_node->node()->input(last_input_index);
  } else {
    // Unhandled situation.
    return "";
  }
}
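
// A brief (hypothetical) illustration of the two handled cases above: if the
// backward P_C's last regular input is "fwd/StatefulPartitionedCall:3", that
// string is returned directly; if the last input is instead an Identity node
// whose only input is "fwd/StatefulPartitionedCall:3", the Identity's input(0)
// string is returned, so callers get the "{forward_node_name}:{index}" form in
// both cases.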

void UpdateForwardIdentityNodeDtype(utils::MutableNodeView* forward_node,
                                    const DataTypeVector& dtypes) {
  const auto& fanouts_vector = forward_node->GetRegularFanouts();
  for (int pos = 0, pos_limit = fanouts_vector.size(); pos < pos_limit; ++pos) {
    const auto& fanouts_at_pos = fanouts_vector[pos];
    for (const auto& fanout : fanouts_at_pos) {
      if ("Identity" == fanout.node_view()->GetOp()) {
        (*fanout.node_view()->node()->mutable_attr())["T"].set_type(
            dtypes[pos]);
        VLOG(3) << "Updated DTYPE for Identity node: "
                << fanout.node_view()->node()->DebugString();
      }
    }
  }
}
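
// For example (hypothetical values): if the forward node's output 2 feeds an
// Identity node and dtypes[2] == DT_FLOAT after the swap, the helper above
// rewrites that Identity's "T" attr to DT_FLOAT; non-Identity fanouts are left
// untouched.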

Status UpdateNodeDef(utils::MutableNodeView* node_view, const string& funcName,
                     const FunctionApiInfo& apiInfo) {
  NodeDef* node_def = node_view->node();

  VLOG(3) << "Node def before swap is: " << node_def->DebugString();

  // For step 1 above.
  node_def->mutable_attr()->find("f")->second.mutable_func()->set_name(
      funcName);

  // For step 2 above.
  auto tin = node_def->mutable_attr()->find("Tin");
  tin->second.mutable_list()->clear_type();
  for (const auto& tin_dtype : apiInfo.input_arg_dtypes()) {
    tin->second.mutable_list()->add_type(tin_dtype);
  }

  // For step 3 above.
  auto tout = node_def->mutable_attr()->find("Tout");
  tout->second.mutable_list()->clear_type();
  for (const auto& tout_dtype : apiInfo.output_arg_dtypes()) {
    tout->second.mutable_list()->add_type(tout_dtype);
  }

  if (apiInfo.function_type() == FunctionApiInfo::BACKWARD) {
    // Strip node control dependencies. We'll add them back after updating
    // all the data inputs.
    std::vector<std::string> control_deps;
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      if (!IsControlInput(node_def->input(i))) break;
      control_deps.push_back(node_def->input(i));
      node_def->mutable_input()->RemoveLast();
    }

    // For step 4 above.
    const int prev_input_size = node_def->input_size();
    const int diff = prev_input_size - apiInfo.input_arg_dtypes().size();
    if (diff >= 0) {
      for (int i = 0; i < diff; ++i) node_def->mutable_input()->RemoveLast();
    } else {
      // Add new inputs for the internal states. The names of the internal
      // states should be in the format "{forward_node_name}:{index}", where
      // the newly added indices continue from the last index of the existing
      // states.
      // Eg:
      // {
      //   input: "gradients/unified_lstm/strided_slice_1_grad/StridedSliceGrad"
      //   input: "gradients/zeros_like_1"
      //   input: "gradients/zeros_like_2"
      //   input: "unified_lstm/StatefulPartitionedCall:3"
      //   input: "unified_lstm/StatefulPartitionedCall:4"
      //   # New input should be "unified_lstm/StatefulPartitionedCall:5"
      // }
      const string last_input = FindForwardNode(node_view);
      const std::vector<string> name_index = ::absl::StrSplit(last_input, ':');
      if (name_index.size() != 2) {
        return errors::InvalidArgument(
            "Invalid format of input node name: ", last_input,
            " Expected: {forward_node_name}:{index}");
      }
      const absl::string_view node_name = name_index[0];
      int last_index;
      if (!::absl::SimpleAtoi(name_index[1], &last_index)) {
        return errors::InvalidArgument(
            "The index of input node is expected to be number, got: ",
            name_index[1]);
      }
      for (int i = 1; i <= -diff; ++i)
        node_def->add_input(strings::StrCat(node_name, ":", i + last_index));
    }

    // Add control dependencies back.
    for (std::string& control : control_deps)
      node_def->add_input(std::move(control));

  } else if (apiInfo.function_type() == FunctionApiInfo::FORWARD) {
    // For the forward function, since the DTYPEs of the intermediate states
    // might have changed, we want to update the downstream Identity nodes, if
    // any. This is step 5 in the comment above.
    UpdateForwardIdentityNodeDtype(node_view, apiInfo.output_arg_dtypes());
  }

  VLOG(3) << "Node def after swap is: " << node_def->DebugString();
  return Status::OK();
}

Status ImplementationSelector::LoadFunctions(const GraphDef& graph) {
  lib_info_ = absl::make_unique<FunctionLibraryApiInfo>();
  TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
  return Status::OK();
}

Status ImplementationSelector::MaybeOptimizeFunctionCall(
    utils::MutableNodeView* node_view) const {
  // There are two ways of calling functions:
  //  1. By specifying an op name as a function name, or
  //  2. Via the @defun functional interface, where the real function call
  //     happens with a PartitionedCall op, and the function name appears in
  //     the attribute named "f" with type func. In this case, more attributes
  //     need to be taken care of, such as Tin and Tout, which describe the
  //     DTYPEs of the inputs/outputs.
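  //
  // A rough sketch of case 2 (hypothetical function name and dtypes):
  //   op: "StatefulPartitionedCall"
  //   attr { key: "f"    value { func { name: "lstm_gpu_impl" } } }
  //   attr { key: "Tin"  value { list { type: DT_FLOAT type: DT_FLOAT } } }
  //   attr { key: "Tout" value { list { type: DT_FLOAT } } }
  // In case 1, the function name is simply the node's op itself.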
  NodeDef* node_def = node_view->node();

  std::vector<string> function_attribute_names;
  for (const auto& attr : node_def->attr()) {
    if (attr.second.has_func() &&
        lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
      function_attribute_names.emplace_back(attr.first);
    }
  }

  if (function_attribute_names.empty() &&
      lib_info_->GetApiInfo(node_def->op()) == nullptr) {
    // A regular op, or a function which has no interface.
    return Status::OK();
  }

  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_def->device(), &parsed_name) ||
      !parsed_name.has_type) {
    return errors::Internal("Could not parse device name:", node_def->device());
  }
  VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
          << " = (" << parsed_name.type << ")";

  for (const auto& attr_name : function_attribute_names) {
    string function_name = node_def->attr().at(attr_name).func().name();
    // Skip the function if it's already optimized by the function optimizer.
    if (::absl::StrContains(function_name, "_specialized_for_")) continue;
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        function_name, &equiv_func_names));
    for (const auto& func_name : equiv_func_names) {
      const auto& func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        VLOG(2) << "Swapping: " << function_name << " TO: " << func_name;
        TF_RETURN_IF_ERROR(UpdateNodeDef(node_view, func_name, *func_api_info));
        break;
      }
    }
  }

  if (lib_info_->GetApiInfo(node_def->op()) != nullptr &&
      !::absl::StrContains(node_def->op(), "_specialized_for_")) {
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        node_def->op(), &equiv_func_names));
    for (const string& func_name : equiv_func_names) {
      const auto func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        node_def->set_op(func_name);
        break;
      }
    }
  }
  return Status::OK();
}

// Finds the index of the device from the device name list.
Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
                       const string& device, int* index) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
      !parsed_name.has_type) {
    return errors::Internal("Could not parse device name:", device);
  }
  const auto& device_list =
      device_index_node->GetAttr("device_names")->list().s();
  auto it = absl::c_find(device_list, parsed_name.type);
  if (it != device_list.end()) {
    *index = it - device_list.begin();
  } else {
    // Sets *index to device_list.size() because the default_fn is guaranteed
    // to be the final item in the case op branching list.
    *index = device_list.size();
  }
  return Status::OK();
}
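
// For example (hypothetical attribute values): with device_names = ["CPU",
// "GPU"], a node placed on "/job:localhost/replica:0/task:0/device:GPU:0"
// yields *index = 1, while an unlisted device type yields
// *index = device_names.size() = 2, which selects the default branch.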

// Rewrites the DeviceIndex op to a Const op with the value of the index.
void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
                          int index) {
  // Modifies the DeviceIndex node to be a Const op with the correct device
  // index.
  auto node = device_index_node->node();
  node->set_op(kConstOp);
  EraseRegularNodeAttributes(node);
  (*node->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  tensor->add_int_val(index);
  VLOG(2) << "Node after rewriting:" << node->DebugString();
}
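
// Illustrative result (for index = 1): the DeviceIndex node above becomes
//   op: "Const"
//   attr { key: "dtype" value { type: DT_INT32 } }
//   attr { key: "value" value { tensor { dtype: DT_INT32 int_val: 1 } } }
// with its other regular attributes erased and its name left unchanged.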

Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
  Status status;
  VLOG(2) << "graph before rewriting device index:" << graph->DebugString();
  utils::MutableGraphView graph_view(graph, &status);
  TF_RETURN_IF_ERROR(status);
  const int num_nodes = graph_view.NumNodes();
  for (int k = 0; k < num_nodes; ++k) {
    auto* node_view = graph_view.GetNode(k);
    if (node_view->GetOp() != kDeviceIndexOp) {
      continue;
    }
    VLOG(2) << "Found a node to rewrite the device index";

    // Find the Case nodes that take the DeviceIndex node as input, and rewrite
    // the DeviceIndex node to hold the index of the device type of the Case
    // node.
    for (const auto& fanouts : node_view->GetRegularFanouts()) {
      for (const auto& fanout : fanouts) {
        if (fanout.node_view()->GetOp() != kCaseOp &&
            fanout.node_view()->GetOp() != kStatelessCaseOp)
          continue;
        int index;
        // If any error occurs during device parsing, we simply skip and do not
        // modify the DeviceIndex node.
        Status status =
            FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
        if (status.ok()) {
          RewriteDeviceIndexOp(node_view, index);
        }
      }
    }
  }
  return Status::OK();
}

Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
  if (!graph->has_library()) {
    VLOG(2) << "Skipping graph since it does not have function def";
    return Status::OK();
  }
  if (lib_info_->empty()) {
    VLOG(2) << "Skipping optimization since lib_info is empty";
    return Status::OK();
  }

  Status status;
  utils::MutableGraphView graph_view(graph, &status);
  TF_RETURN_IF_ERROR(status);

  const int num_nodes = graph_view.NumNodes();
  for (int k = 0; k < num_nodes; ++k) {
    TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
  }

  return Status::OK();
}

Status ImplementationSelector::Optimize(Cluster* cluster,
                                        const GrapplerItem& item,
                                        GraphDef* optimized_graph) {
  auto status = LoadFunctions(item.graph);
  // Swallow the error from function loading, since this optimizer might run
  // several times, and might try to run against functions generated by
  // function_optimizer from previous runs, which would fail due to a function
  // signature mismatch.
  if (!status.ok()) {
    VLOG(2) << "Skipping optimization due to error while loading function "
            << "libraries: " << status;
    return errors::Aborted("Skipped Optimization");
  }

  *optimized_graph = item.graph;
  status = SelectDeviceIndex(optimized_graph);
  if (!status.ok()) {
    *optimized_graph = item.graph;
    VLOG(2) << "Could not rewrite device index due to error:" << status;
  }
  return SelectImplementation(optimized_graph);
}

}  // end namespace grappler
}  // end namespace tensorflow