/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/implementation_selector.h"

#include <string>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/function_api_info.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

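// Rewrites node_def in place to call funcName instead of its current
// function: the "Tin"/"Tout" dtype lists are replaced with the dtypes
// declared by apiInfo, and, for a backward function, the input list is grown
// or shrunk to match the new implementation's signature.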
Status UpdateNodeDef(NodeDef* node_def, const string& funcName,
                     const FunctionApiInfo& apiInfo) {
  VLOG(3) << "Node def before swap is: " << node_def->DebugString();
  auto tin = node_def->mutable_attr()->find("Tin");
  tin->second.mutable_list()->clear_type();
  for (const auto& tin_dtype : apiInfo.input_arg_dtypes()) {
    tin->second.mutable_list()->add_type(tin_dtype);
  }

  auto tout = node_def->mutable_attr()->find("Tout");
  tout->second.mutable_list()->clear_type();
  for (const auto& tout_dtype : apiInfo.output_arg_dtypes()) {
    tout->second.mutable_list()->add_type(tout_dtype);
  }

  if (apiInfo.function_type() == FunctionApiInfo::BACKWARD) {
    // Update the inputs: a backward function may take a different number of
    // inputs because its forward function may produce a different number of
    // outputs. The outputs of a forward function consist of two parts:
    //   1. The real output tensors of the defun.
    //   2. Internal states that are consumed by the gradient calculation.
    // Part 1 is fixed, while part 2 can differ between implementations.

    const int prev_input_size = node_def->input_size();
    const int diff = prev_input_size - apiInfo.input_arg_dtypes().size();
    if (diff >= 0) {
      for (int i = 0; i < diff; ++i) node_def->mutable_input()->RemoveLast();
    } else {
      // Add new inputs for the internal states. Each internal state input is
      // named "{forward_node_name}:{index}", and the newly added indices
      // continue from the last existing index.
      // Eg:
      // {
      //   input: "gradients/unified_lstm/strided_slice_1_grad/StridedSliceGrad"
      //   input: "gradients/zeros_like_1"
      //   input: "gradients/zeros_like_2"
      //   input: "unified_lstm/StatefulPartitionedCall:3"
      //   input: "unified_lstm/StatefulPartitionedCall:4"
      //   # New input should be "unified_lstm/StatefulPartitionedCall:5"
      // }
      const string last_input = node_def->input(prev_input_size - 1);
      const std::vector<string> name_index = ::absl::StrSplit(last_input, ':');
      if (name_index.size() != 2) {
        return errors::InvalidArgument(
            "Invalid format of input node name: ", last_input,
            " Expected: {forward_node_name}:{index}");
      }
      const absl::string_view node_name = name_index[0];
      int last_index;
      if (!::absl::SimpleAtoi(name_index[1], &last_index)) {
        return errors::InvalidArgument(
            "The index of the input node is expected to be a number, got: ",
            name_index[1]);
      }
      for (int i = 1; i <= -diff; ++i)
        node_def->add_input(strings::StrCat(node_name, ":", i + last_index));
    }
  }

  node_def->mutable_attr()->find("f")->second.mutable_func()->set_name(
      funcName);

  VLOG(3) << "Node def after swap is: " << node_def->DebugString();
  return Status::OK();
}

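// Builds the per-function API info (lib_info_) from the graph's function
// library.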
Status ImplementationSelector::LoadFunctions(const GraphDef& graph) {
  lib_info_.reset(new FunctionLibraryApiInfo);
  TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
  return Status::OK();
}

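// If node_def calls a function that has equivalent implementations in the
// library, swaps the call to the implementation whose preferred device
// matches the device the node is placed on.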
Status ImplementationSelector::MaybeOptimizeFunctionCall(
    NodeDef* node_def) const {
  // There are two ways of calling a function:
  //  1. By specifying the function name as the op name, or
  //  2. Via the @defun functional interface, where the actual call happens
  //     through a PartitionedCall op and the function name appears in the
  //     attribute named "f" with type func. In this case, additional
  //     attributes such as Tin and Tout, which carry the dtypes of the
  //     inputs/outputs, need to be updated as well.
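  // A hypothetical pbtxt sketch of case 2 (node and function names here are
  // made up for illustration, not taken from a real graph):
  //   node {
  //     name: "model/StatefulPartitionedCall"
  //     op: "StatefulPartitionedCall"
  //     attr { key: "f" value { func { name: "my_func_gpu_impl" } } }
  //     attr { key: "Tin" value { list { type: DT_FLOAT } } }
  //     attr { key: "Tout" value { list { type: DT_FLOAT } } }
  //   }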
  std::vector<string> function_attribute_names;
  for (const auto& attr : node_def->attr()) {
    if (attr.second.has_func() &&
        lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
      function_attribute_names.emplace_back(attr.first);
    }
  }

  if (function_attribute_names.empty() &&
      lib_info_->GetApiInfo(node_def->op()) == nullptr) {
    // A regular op, or a function which has no interface.
    return Status::OK();
  }

  string task, device;
  if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
    return errors::Internal("Could not split device name: ",
                            node_def->device());
  }
  VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
          << " = (" << task << ", " << device << ")";
  DeviceNameUtils::ParsedName parsed_name;
  DeviceNameUtils::ParseLocalName(device, &parsed_name);

  for (const auto& attr_name : function_attribute_names) {
    string function_name = node_def->attr().at(attr_name).func().name();
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        function_name, &equiv_func_names));
    for (const auto& func_name : equiv_func_names) {
      const auto& func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        VLOG(2) << "Swapping: " << function_name << " TO: " << func_name;
        TF_RETURN_IF_ERROR(UpdateNodeDef(node_def, func_name, *func_api_info));
        break;
      }
    }
  }

  if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        node_def->op(), &equiv_func_names));
    for (const string& func_name : equiv_func_names) {
      const auto func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        node_def->set_op(func_name);
        break;
      }
    }
  }
  return Status::OK();
}

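// Walks every node in the graph and rewrites any function call for which the
// library declares an implementation preferred for the node's device.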
Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
  if (!graph->has_library()) {
    VLOG(2) << "Skipping graph since it does not have function def";
    return Status::OK();
  }
  if (lib_info_->empty()) {
    VLOG(2) << "Skipping optimization since lib_info is empty";
    return Status::OK();
  }

  for (int k = 0; k < graph->node_size(); ++k)
    TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k)));

  return Status::OK();
}

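// Grappler entry point: copies the input graph, loads the function API info
// from its library, and runs the implementation selection pass over it.
//
// A minimal usage sketch (hypothetical caller code; assumes `item` is a
// populated GrapplerItem; the cluster argument is unused by this optimizer):
//
//   ImplementationSelector selector;
//   GraphDef optimized_graph;
//   TF_CHECK_OK(selector.Optimize(/*cluster=*/nullptr, item,
//                                 &optimized_graph));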
Status ImplementationSelector::Optimize(Cluster* cluster,
                                        const GrapplerItem& item,
                                        GraphDef* optimized_graph) {
  *optimized_graph = item.graph;
  TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph));
  return SelectImplementation(optimized_graph);
}

}  // end namespace grappler
}  // end namespace tensorflow