/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/utils.h"

#include <algorithm>
#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegates {

CreateNewTensorWithDifferentType(TfLiteContext * context,const int original_tensor_index,TfLiteType new_type,TfLiteTensor ** new_tensor,int * new_tensor_index)28 TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
29 const int original_tensor_index,
30 TfLiteType new_type,
31 TfLiteTensor** new_tensor,
32 int* new_tensor_index) {
33 TF_LITE_ENSURE_STATUS(context->AddTensors(context, 1, new_tensor_index));
34 const TfLiteTensor& original_tensor = context->tensors[original_tensor_index];
35 *new_tensor = &context->tensors[*new_tensor_index];
36 (*new_tensor)->type = new_type;
37 (*new_tensor)->allocation_type = kTfLiteArenaRw;
38 const auto* original_dims = original_tensor.dims;
39 TfLiteIntArray* dims = TfLiteIntArrayCreate(original_dims->size);
40 for (int i = 0; i < original_dims->size; ++i) {
41 dims->data[i] = original_dims->data[i];
42 }
43 if (context->ResizeTensor(context, *new_tensor, dims) != kTfLiteOk) {
44 TF_LITE_KERNEL_LOG(context, "Could not resize new delegate tensor");
45 return kTfLiteError;
46 }
47 return kTfLiteOk;
48 }
49
Partition(std::set<std::string> * unsupported_nodes_info)50 TfLiteStatus GraphPartitionHelper::Partition(
51 std::set<std::string>* unsupported_nodes_info) {
52 const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
53 if (prepare_status != kTfLiteOk) return prepare_status;
54
55 TfLiteDelegateParams* partition_params_array_ = nullptr;
56 int num_partitions_ = 0;
57 if (context_->PreviewDelegatePartitioning(context_, supported_nodes_,
58 &partition_params_array_,
59 &num_partitions_) != kTfLiteOk) {
60 TF_LITE_KERNEL_LOG(context_, "Unable to preview delegate partition.\n");
61 return kTfLiteError;
62 }
63
64 for (int i = 0; i < num_partitions_; ++i) {
65 partitions_.push_back(partition_params_array_ + i);
66 }
67
68 return kTfLiteOk;
69 }
70
71 std::vector<TfLiteDelegateParams*>
GetFirstNLargestPartitions(int n,int min_nodes_per_partition) const72 GraphPartitionHelper::GetFirstNLargestPartitions(
73 int n, int min_nodes_per_partition) const {
74 // In general, the number of partitions in a delegate is never likely to be
75 // high enough to cause latency issues. Also considering this is generally a
76 // one-time work, we simply unconditionally sort partitions here according to
77 // the size.
78 std::vector<TfLiteDelegateParams*> sorted_partitions(partitions_);
79 std::sort(sorted_partitions.begin(), sorted_partitions.end(),
80 [](TfLiteDelegateParams* left, TfLiteDelegateParams* right) {
81 // Reverse sort
82 return left->nodes_to_replace->size >
83 right->nodes_to_replace->size;
84 });
85
86 std::vector<TfLiteDelegateParams*> results;
87 auto p_it = sorted_partitions.begin();
88 const int total = sorted_partitions.size();
89 for (int i = 0; i < std::min(total, n); ++i, ++p_it) {
90 auto* p = (*p_it);
91 if (p->nodes_to_replace->size < min_nodes_per_partition) {
92 break;
93 }
94 results.push_back(p);
95 }
96 return results;
97 }
98
GetNodesOfFirstNLargestPartitionsImpl(int n,int min_nodes_per_partition)99 std::vector<int> GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
100 int n, int min_nodes_per_partition) {
101 auto first_n_partitions =
102 GetFirstNLargestPartitions(n, min_nodes_per_partition);
103 std::vector<int> ops_to_replace;
104 for (const auto p : first_n_partitions) {
105 auto nodes = p->nodes_to_replace;
106 ops_to_replace.insert(ops_to_replace.end(), nodes->data,
107 nodes->data + nodes->size);
108 }
109 return ops_to_replace;
110 }
111
PrepareSupportedNodes(std::set<std::string> * unsupported_nodes_info)112 TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
113 std::set<std::string>* unsupported_nodes_info) {
114 if (!is_node_supported_fn_) return kTfLiteOk;
115
116 TfLiteIntArray* execution_plan = nullptr;
117 auto status = context_->GetExecutionPlan(context_, &execution_plan);
118 if (status != kTfLiteOk) {
119 TF_LITE_KERNEL_LOG(context_, "Unable to get graph execution plan.\n");
120 return status;
121 }
122
123 num_total_nodes_ = execution_plan->size;
124 supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_);
125 supported_nodes_->size = 0;
126 for (int node_id : TfLiteIntArrayView(execution_plan)) {
127 TfLiteNode* node;
128 TfLiteRegistration* registration;
129
130 status = context_->GetNodeAndRegistration(context_, node_id, &node,
131 ®istration);
132 if (status != kTfLiteOk) {
133 TF_LITE_KERNEL_LOG(context_,
134 "Couldn't get node and registration info for op: %d\n",
135 node_id);
136 supported_nodes_->size = 0;
137 return status;
138 }
139
140 std::string unsupported_details;
141 if (IsNodeSupported(context_, node, registration, node_id,
142 &unsupported_details)) {
143 supported_nodes_->data[supported_nodes_->size++] = node_id;
144 } else if (unsupported_nodes_info) {
145 std::string node_info = GetOpNameByRegistration(*registration);
146 node_info.append(": ");
147 node_info.append(unsupported_details);
148 unsupported_nodes_info->insert(node_info);
149 }
150 }
151 return kTfLiteOk;
152 }
153
154 std::vector<int>
GetNodesOfFirstNLargestPartitionsImpl(int n,int min_nodes_per_partition)155 FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
156 int n, int min_nodes_per_partition) {
157 auto first_n_partitions =
158 GetFirstNLargestPartitions(n, min_nodes_per_partition);
159 std::vector<int> ops_to_replace;
160 if (first_n_partitions.empty()) return ops_to_replace;
161
162 // Handle the first delegated partition specially.
163 // All fp16 DEQUANTIZE nodes whose consumers exist only in this partition can
164 // be added to the ops to delegate. Others have to be preserved in the graph,
165 // since the partitioning algorithm will put such nodes greedily in the first
166 // partition.
167 const auto* first_partition = first_n_partitions[0];
168 std::unordered_map<int, int> delegated_dequant_consumers;
169 for (int i = 0; i < first_partition->nodes_to_replace->size; ++i) {
170 const int node_id = first_partition->nodes_to_replace->data[i];
171 ops_to_replace.push_back(node_id);
172 TfLiteNode* node;
173 TfLiteRegistration* registration;
174 const auto status = context_->GetNodeAndRegistration(context_, node_id,
175 &node, ®istration);
176 if (status != kTfLiteOk) {
177 TF_LITE_KERNEL_LOG(context_,
178 "Couldn't get node and registration info for op: %d\n",
179 node_id);
180 ops_to_replace.clear();
181 return ops_to_replace;
182 }
183 // See if any input to the op is a (converted) fp16 value. If yes, increment
184 // its value in delegated_dequant_consumers.
185 for (int j = 0; j < node->inputs->size; ++j) {
186 const int input_tid = node->inputs->data[j];
187 if (constant_dequant_consumers_.find(input_tid) !=
188 constant_dequant_consumers_.end()) {
189 delegated_dequant_consumers[input_tid] += 1;
190 }
191 }
192 }
193 // Check all dequant nodes that have some consumers in the first partition.
194 // If the number of delegated consumers is same as total number of consumers,
195 // add the corresponding DEQUANTIZE op to the delegated nodes.
196 for (auto tensor_and_consumers : delegated_dequant_consumers) {
197 if (constant_dequant_consumers_[tensor_and_consumers.first] ==
198 tensor_and_consumers.second) {
199 ops_to_replace.emplace_back(
200 constant_dequant_nodes_[tensor_and_consumers.first]);
201 }
202 }
203
204 // For all other partitions after the first one, insert all nodes into
205 // ops_to_replace.
206 for (int i = 1; i < first_n_partitions.size(); ++i) {
207 auto nodes = first_n_partitions[i]->nodes_to_replace;
208 ops_to_replace.insert(ops_to_replace.end(), nodes->data,
209 nodes->data + nodes->size);
210 }
211
212 // Modify the inputs of relevant ops that support fp16 constants.
213 // TODO(b/156707497): Ensure that these inputs are remapped during the
214 // delegate's 'free', so that CPU fallback works for fp16 models.
215 RemapFp16InputTensors(ops_to_replace);
216 return ops_to_replace;
217 }
218
IsNodeSupported(TfLiteContext * context,TfLiteNode * node,TfLiteRegistration * registration,int node_id,std::string * unsupported_details)219 bool FP16GraphPartitionHelper::IsNodeSupported(
220 TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
221 int node_id, std::string* unsupported_details) {
222 if (registration->builtin_code == kTfLiteBuiltinDequantize) {
223 auto& dequantize_input = context_->tensors[node->inputs->data[0]];
224 if (dequantize_input.type == kTfLiteFloat16 &&
225 IsConstantTensor(&dequantize_input)) {
226 // Update mappings if this node is a fp16 DEQUANTIZE node that
227 // works on a **constant** input tensor.
228 // If the input is not a constant, the remapping that we do here will
229 // cause bugs due to preceding ops such as DENSIFY.
230 constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
231 constant_dequant_nodes_[node->outputs->data[0]] = node_id;
232 // We do not accept these ops right now.
233 // This is done to support use-cases where a DEQUANTIZE output might be
234 // consumed by a CPU op.
235 return false;
236 }
237 }
238
239 // To check if a (possibly) FP16 node is supported, we temporarily point the
240 // node's inputs to the original fp16 tensors. This 'mutated' node is then
241 // passed to the base IsNodeSupported function for checking. After the check,
242 // we remap the original node inputs, so that the TFLite graph remains the
243 // same.
244 std::vector<int> orig_inputs;
245 if (!constant_dequant_nodes_.empty()) {
246 RemapFp16InputTensors(node, &orig_inputs);
247 }
248
249 const auto is_supported = GraphPartitionHelper::IsNodeSupported(
250 context, node, registration, node_id, unsupported_details);
251
252 if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
253 // Remapping happened. Restore original inputs.
254 for (int j = 0; j < node->inputs->size; ++j) {
255 node->inputs->data[j] = orig_inputs[j];
256 if (constant_dequant_nodes_.find(orig_inputs[j]) !=
257 constant_dequant_nodes_.end()) {
258 // If its a fp16 tensor, increment number of consumers of the
259 // corresponding DEQUANTIZE.
260 constant_dequant_consumers_[orig_inputs[j]] += 1;
261 }
262 }
263 }
264 return is_supported;
265 }
266
RemapFp16InputTensors(const std::vector<int> & nodes) const267 void FP16GraphPartitionHelper::RemapFp16InputTensors(
268 const std::vector<int>& nodes) const {
269 for (int node_id : nodes) {
270 TfLiteNode* node;
271 TfLiteRegistration* registration;
272 TfLiteStatus status = context_->GetNodeAndRegistration(
273 context_, node_id, &node, ®istration);
274 if (status != kTfLiteOk) {
275 TF_LITE_KERNEL_LOG(context_,
276 "Couldn't get node and registration info for op: %d\n",
277 node_id);
278 }
279 RemapFp16InputTensors(node, nullptr /* orig_inputs*/);
280 }
281 }
282
RemapFp16InputTensors(TfLiteNode * node,std::vector<int> * orig_inputs) const283 void FP16GraphPartitionHelper::RemapFp16InputTensors(
284 TfLiteNode* node, std::vector<int>* orig_inputs) const {
285 TfLiteIntArray* inputs = node->inputs;
286 auto inputs_view = TfLiteIntArrayView(inputs);
287 // Prepopulate 'orig_inputs' first and clear it if there's no input from a
288 // dequant op.
289 if (orig_inputs) {
290 orig_inputs->clear();
291 orig_inputs->reserve(inputs->size);
292 for (auto tid : inputs_view) {
293 orig_inputs->push_back(tid);
294 }
295 }
296 // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
297 // order to test if it is supported.
298 bool is_remapped = false;
299 for (int j = 0; j < inputs->size; ++j) {
300 const int input_tid = inputs->data[j];
301 const auto it = constant_dequant_map_.find(input_tid);
302 if (it != constant_dequant_map_.end()) {
303 inputs->data[j] = it->second;
304 is_remapped = true;
305 }
306 }
307 if (!is_remapped && orig_inputs) orig_inputs->clear();
308 }
309
}  // namespace delegates
}  // namespace tflite