/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/implementation_selector.h"

#include <string>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/function_api_info.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {
Status UpdateNodeDef(NodeDef* node_def, const string& funcName,
                     const FunctionApiInfo& apiInfo) {
  VLOG(3) << "Node def before swap is: " << node_def->DebugString();
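  // Rewrite the input/output dtype lists so they match the signature of the
  // new implementation.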
  auto tin = node_def->mutable_attr()->find("Tin");
  tin->second.mutable_list()->clear_type();
  for (const auto& tin_dtype : apiInfo.input_arg_dtypes()) {
    tin->second.mutable_list()->add_type(tin_dtype);
  }

  auto tout = node_def->mutable_attr()->find("Tout");
  tout->second.mutable_list()->clear_type();
  for (const auto& tout_dtype : apiInfo.output_arg_dtypes()) {
    tout->second.mutable_list()->add_type(tout_dtype);
  }

  if (apiInfo.function_type() == FunctionApiInfo::BACKWARD) {
    // Update the inputs: a backward function might have a different number of
    // inputs because the number of outputs of its forward function differs
    // between implementations. The outputs of the forward function are
    // composed of two parts:
    // 1. Real output tensors from the defun.
    // 2. Internal states that will be used for gradient calculation.
    // Part 1 is fixed, while part 2 can vary between implementations.

    const int prev_input_size = node_def->input_size();
    const int diff = prev_input_size - apiInfo.input_arg_dtypes().size();
    if (diff >= 0) {
      for (int i = 0; i < diff; ++i) node_def->mutable_input()->RemoveLast();
    } else {
      // Add new inputs for the internal states. The name of an internal
      // state should be in the format "{forward_node_name}:{index}", where
      // the newly added indices continue from the last existing index.
      // Eg:
      // {
      //   input: "gradients/unified_lstm/strided_slice_1_grad/StridedSliceGrad"
      //   input: "gradients/zeros_like_1"
      //   input: "gradients/zeros_like_2"
      //   input: "unified_lstm/StatefulPartitionedCall:3"
      //   input: "unified_lstm/StatefulPartitionedCall:4"
      //   # New input should be "unified_lstm/StatefulPartitionedCall:5"
      // }
      const string last_input = node_def->input(prev_input_size - 1);
      const std::vector<string> name_index = ::absl::StrSplit(last_input, ':');
      if (name_index.size() != 2) {
        return errors::InvalidArgument(
            "Invalid format of input node name: ", last_input,
            " Expected: {forward_node_name}:{index}");
      }
      const absl::string_view node_name = name_index[0];
      int last_index;
      if (!::absl::SimpleAtoi(name_index[1], &last_index)) {
        return errors::InvalidArgument(
            "The index of the input node is expected to be a number, got: ",
            name_index[1]);
      }
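      // Append the extra internal-state inputs, continuing from the last
      // index observed on the forward node.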
      for (int i = 1; i <= -diff; ++i)
        node_def->add_input(strings::StrCat(node_name, ":", i + last_index));
    }
  }

  node_def->mutable_attr()->find("f")->second.mutable_func()->set_name(
      funcName);

  VLOG(3) << "Node def after swap is: " << node_def->DebugString();
  return Status::OK();
}

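// Builds the function API info map (lib_info_) from the function library of
// `graph`; this is the lookup table used to find equivalent implementations.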
Status ImplementationSelector::LoadFunctions(const GraphDef& graph) {
  lib_info_.reset(new FunctionLibraryApiInfo);
  TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
  return Status::OK();
}

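// Examines a single node and, if it calls a function that has an API
// interface, swaps the call to the equivalent implementation whose preferred
// device matches the device the node is placed on.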
Status ImplementationSelector::MaybeOptimizeFunctionCall(
    NodeDef* node_def) const {
  // There are two ways of calling functions:
  // 1. By specifying an op name as a function name, or
  // 2. Via the @defun functional interface, where the real function call
  //    happens with a PartitionedCall op, and the function name appears in
  //    the attribute named "f" with type func. In this case, more attributes
  //    need to be taken care of, like Tin and Tout, which hold the dtypes of
  //    the inputs/outputs.
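  // For illustration (the field values here are hypothetical), a case-2 call
  // node in text proto form might look like:
  //   op: "StatefulPartitionedCall"
  //   attr { key: "f"    value { func { name: "lstm_gpu" } } }
  //   attr { key: "Tin"  value { list { type: DT_FLOAT } } }
  //   attr { key: "Tout" value { list { type: DT_FLOAT } } }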
  std::vector<string> function_attribute_names;
  for (const auto& attr : node_def->attr()) {
    if (attr.second.has_func() &&
        lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
      function_attribute_names.emplace_back(attr.first);
    }
  }

  if (function_attribute_names.empty() &&
      lib_info_->GetApiInfo(node_def->op()) == nullptr) {
    // A regular op, or a function which has no interface.
    return Status::OK();
  }

  string task, device;
  if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
    return errors::Internal("Could not split device name: ",
                            node_def->device());
  }
  VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
          << " = (" << task << ", " << device << ")";
  DeviceNameUtils::ParsedName parsed_name;
  DeviceNameUtils::ParseLocalName(device, &parsed_name);
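  // parsed_name.type now holds the local device type, e.g. "CPU" or "GPU",
  // which is matched against each candidate's preferred_device() below.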

  for (const auto& attr_name : function_attribute_names) {
    string function_name = node_def->attr().at(attr_name).func().name();
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        function_name, &equiv_func_names));
    for (const auto& func_name : equiv_func_names) {
      const auto& func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        VLOG(2) << "Swapping: " << function_name << " TO: " << func_name;
        TF_RETURN_IF_ERROR(UpdateNodeDef(node_def, func_name, *func_api_info));
        break;
      }
    }
  }

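  // Case 1 above: the node's op itself names a function with an API
  // interface, so the swap only requires renaming the op.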
  if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
    std::vector<string> equiv_func_names;
    TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
        node_def->op(), &equiv_func_names));
    for (const string& func_name : equiv_func_names) {
      const auto func_api_info = lib_info_->GetApiInfo(func_name);
      if (func_api_info->preferred_device() == parsed_name.type) {
        node_def->set_op(func_name);
        break;
      }
    }
  }
  return Status::OK();
}

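// Walks every node in `graph` and rewrites function calls to the
// implementation preferred for the node's device, when one is available.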
Status ImplementationSelector::SelectImplementation(
    GraphDef* graph) const {
  if (!graph->has_library()) {
    VLOG(2) << "Skipping graph since it does not have a function library";
    return Status::OK();
  }
  if (lib_info_->empty()) {
    VLOG(2) << "Skipping optimization since lib_info is empty";
    return Status::OK();
  }

  for (int k = 0; k < graph->node_size(); ++k)
    TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k)));

  return Status::OK();
}

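// Grappler entry point: copies the input graph, loads its function library,
// and then rewrites function calls in place.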
Status ImplementationSelector::Optimize(Cluster* cluster,
                                        const GrapplerItem& item,
                                        GraphDef* optimized_graph) {
  *optimized_graph = item.graph;
  TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph));
  return SelectImplementation(optimized_graph);
}

}  // end namespace grappler
}  // end namespace tensorflow