• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/utils.h"
17 
#include <algorithm>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
20 
21 #include "tensorflow/lite/builtin_ops.h"
22 #include "tensorflow/lite/context_util.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace delegates {
27 
CreateNewTensorWithDifferentType(TfLiteContext * context,const int original_tensor_index,TfLiteType new_type,TfLiteTensor ** new_tensor,int * new_tensor_index)28 TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
29                                               const int original_tensor_index,
30                                               TfLiteType new_type,
31                                               TfLiteTensor** new_tensor,
32                                               int* new_tensor_index) {
33   TF_LITE_ENSURE_STATUS(context->AddTensors(context, 1, new_tensor_index));
34   const TfLiteTensor& original_tensor = context->tensors[original_tensor_index];
35   *new_tensor = &context->tensors[*new_tensor_index];
36   (*new_tensor)->type = new_type;
37   (*new_tensor)->allocation_type = kTfLiteArenaRw;
38   const auto* original_dims = original_tensor.dims;
39   TfLiteIntArray* dims = TfLiteIntArrayCreate(original_dims->size);
40   for (int i = 0; i < original_dims->size; ++i) {
41     dims->data[i] = original_dims->data[i];
42   }
43   if (context->ResizeTensor(context, *new_tensor, dims) != kTfLiteOk) {
44     TF_LITE_KERNEL_LOG(context, "Could not resize new delegate tensor");
45     return kTfLiteError;
46   }
47   return kTfLiteOk;
48 }
49 
Partition(std::set<std::string> * unsupported_nodes_info)50 TfLiteStatus GraphPartitionHelper::Partition(
51     std::set<std::string>* unsupported_nodes_info) {
52   const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
53   if (prepare_status != kTfLiteOk) return prepare_status;
54 
55   TfLiteDelegateParams* partition_params_array_ = nullptr;
56   int num_partitions_ = 0;
57   if (context_->PreviewDelegatePartitioning(context_, supported_nodes_,
58                                             &partition_params_array_,
59                                             &num_partitions_) != kTfLiteOk) {
60     TF_LITE_KERNEL_LOG(context_, "Unable to preview delegate partition.\n");
61     return kTfLiteError;
62   }
63 
64   for (int i = 0; i < num_partitions_; ++i) {
65     partitions_.push_back(partition_params_array_ + i);
66   }
67 
68   return kTfLiteOk;
69 }
70 
71 std::vector<TfLiteDelegateParams*>
GetFirstNLargestPartitions(int n,int min_nodes_per_partition) const72 GraphPartitionHelper::GetFirstNLargestPartitions(
73     int n, int min_nodes_per_partition) const {
74   // In general, the number of partitions in a delegate is never likely to be
75   // high enough to cause latency issues. Also considering this is generally a
76   // one-time work, we simply unconditionally sort partitions here according to
77   // the size.
78   std::vector<TfLiteDelegateParams*> sorted_partitions(partitions_);
79   std::sort(sorted_partitions.begin(), sorted_partitions.end(),
80             [](TfLiteDelegateParams* left, TfLiteDelegateParams* right) {
81               // Reverse sort
82               return left->nodes_to_replace->size >
83                      right->nodes_to_replace->size;
84             });
85 
86   std::vector<TfLiteDelegateParams*> results;
87   auto p_it = sorted_partitions.begin();
88   const int total = sorted_partitions.size();
89   for (int i = 0; i < std::min(total, n); ++i, ++p_it) {
90     auto* p = (*p_it);
91     if (p->nodes_to_replace->size < min_nodes_per_partition) {
92       break;
93     }
94     results.push_back(p);
95   }
96   return results;
97 }
98 
GetNodesOfFirstNLargestPartitionsImpl(int n,int min_nodes_per_partition)99 std::vector<int> GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
100     int n, int min_nodes_per_partition) {
101   auto first_n_partitions =
102       GetFirstNLargestPartitions(n, min_nodes_per_partition);
103   std::vector<int> ops_to_replace;
104   for (const auto p : first_n_partitions) {
105     auto nodes = p->nodes_to_replace;
106     ops_to_replace.insert(ops_to_replace.end(), nodes->data,
107                           nodes->data + nodes->size);
108   }
109   return ops_to_replace;
110 }
111 
PrepareSupportedNodes(std::set<std::string> * unsupported_nodes_info)112 TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
113     std::set<std::string>* unsupported_nodes_info) {
114   if (!is_node_supported_fn_) return kTfLiteOk;
115 
116   TfLiteIntArray* execution_plan = nullptr;
117   auto status = context_->GetExecutionPlan(context_, &execution_plan);
118   if (status != kTfLiteOk) {
119     TF_LITE_KERNEL_LOG(context_, "Unable to get graph execution plan.\n");
120     return status;
121   }
122 
123   num_total_nodes_ = execution_plan->size;
124   supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_);
125   supported_nodes_->size = 0;
126   for (int node_id : TfLiteIntArrayView(execution_plan)) {
127     TfLiteNode* node;
128     TfLiteRegistration* registration;
129 
130     status = context_->GetNodeAndRegistration(context_, node_id, &node,
131                                               &registration);
132     if (status != kTfLiteOk) {
133       TF_LITE_KERNEL_LOG(context_,
134                          "Couldn't get node and registration info for op: %d\n",
135                          node_id);
136       supported_nodes_->size = 0;
137       return status;
138     }
139 
140     std::string unsupported_details;
141     if (IsNodeSupported(context_, node, registration, node_id,
142                         &unsupported_details)) {
143       supported_nodes_->data[supported_nodes_->size++] = node_id;
144     } else if (unsupported_nodes_info) {
145       std::string node_info = GetOpNameByRegistration(*registration);
146       node_info.append(": ");
147       node_info.append(unsupported_details);
148       unsupported_nodes_info->insert(node_info);
149     }
150   }
151   return kTfLiteOk;
152 }
153 
154 std::vector<int>
GetNodesOfFirstNLargestPartitionsImpl(int n,int min_nodes_per_partition)155 FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
156     int n, int min_nodes_per_partition) {
157   auto first_n_partitions =
158       GetFirstNLargestPartitions(n, min_nodes_per_partition);
159   std::vector<int> ops_to_replace;
160   if (first_n_partitions.empty()) return ops_to_replace;
161 
162   // Handle the first delegated partition specially.
163   // All fp16 DEQUANTIZE nodes whose consumers exist only in this partition can
164   // be added to the ops to delegate. Others have to be preserved in the graph,
165   // since the partitioning algorithm will put such nodes greedily in the first
166   // partition.
167   const auto* first_partition = first_n_partitions[0];
168   std::unordered_map<int, int> delegated_dequant_consumers;
169   for (int i = 0; i < first_partition->nodes_to_replace->size; ++i) {
170     const int node_id = first_partition->nodes_to_replace->data[i];
171     ops_to_replace.push_back(node_id);
172     TfLiteNode* node;
173     TfLiteRegistration* registration;
174     const auto status = context_->GetNodeAndRegistration(context_, node_id,
175                                                          &node, &registration);
176     if (status != kTfLiteOk) {
177       TF_LITE_KERNEL_LOG(context_,
178                          "Couldn't get node and registration info for op: %d\n",
179                          node_id);
180       ops_to_replace.clear();
181       return ops_to_replace;
182     }
183     // See if any input to the op is a (converted) fp16 value. If yes, increment
184     // its value in delegated_dequant_consumers.
185     for (int j = 0; j < node->inputs->size; ++j) {
186       const int input_tid = node->inputs->data[j];
187       if (constant_dequant_consumers_.find(input_tid) !=
188           constant_dequant_consumers_.end()) {
189         delegated_dequant_consumers[input_tid] += 1;
190       }
191     }
192   }
193   // Check all dequant nodes that have some consumers in the first partition.
194   // If the number of delegated consumers is same as total number of consumers,
195   // add the corresponding DEQUANTIZE op to the delegated nodes.
196   for (auto tensor_and_consumers : delegated_dequant_consumers) {
197     if (constant_dequant_consumers_[tensor_and_consumers.first] ==
198         tensor_and_consumers.second) {
199       ops_to_replace.emplace_back(
200           constant_dequant_nodes_[tensor_and_consumers.first]);
201     }
202   }
203 
204   // For all other partitions after the first one, insert all nodes into
205   // ops_to_replace.
206   for (int i = 1; i < first_n_partitions.size(); ++i) {
207     auto nodes = first_n_partitions[i]->nodes_to_replace;
208     ops_to_replace.insert(ops_to_replace.end(), nodes->data,
209                           nodes->data + nodes->size);
210   }
211 
212   // Modify the inputs of relevant ops that support fp16 constants.
213   // TODO(b/156707497): Ensure that these inputs are remapped during the
214   // delegate's 'free', so that CPU fallback works for fp16 models.
215   RemapFp16InputTensors(ops_to_replace);
216   return ops_to_replace;
217 }
218 
// Decides whether |node| can be delegated, with special handling for fp16:
// constant fp16 DEQUANTIZE nodes are recorded (and rejected), and other nodes
// are temporarily rewired to consume the fp16 constants directly before being
// checked by the base class. The node's inputs are restored before returning.
bool FP16GraphPartitionHelper::IsNodeSupported(
    TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
    int node_id, std::string* unsupported_details) {
  if (registration->builtin_code == kTfLiteBuiltinDequantize) {
    auto& dequantize_input = context_->tensors[node->inputs->data[0]];
    if (dequantize_input.type == kTfLiteFloat16 &&
        IsConstantTensor(&dequantize_input)) {
      // Update mappings if this node is a fp16 DEQUANTIZE node that
      // works on a **constant** input tensor.
      // If the input is not a constant, the remapping that we do here will
      // cause bugs due to preceding ops such as DENSIFY.
      constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
      constant_dequant_nodes_[node->outputs->data[0]] = node_id;
      // We do not accept these ops right now.
      // This is done to support use-cases where a DEQUANTIZE output might be
      // consumed by a CPU op.
      return false;
    }
  }

  // To check if a (possibly) FP16 node is supported, we temporarily point the
  // node's inputs to the original fp16 tensors. This 'mutated' node is then
  // passed to the base IsNodeSupported function for checking. After the check,
  // we remap the original node inputs, so that the TFLite graph remains the
  // same.
  std::vector<int> orig_inputs;
  if (!constant_dequant_nodes_.empty()) {
    // Fills |orig_inputs| with the pre-remap tensor ids (or clears it if no
    // input actually came from a constant fp16 DEQUANTIZE).
    RemapFp16InputTensors(node, &orig_inputs);
  }

  const auto is_supported = GraphPartitionHelper::IsNodeSupported(
      context, node, registration, node_id, unsupported_details);

  // A non-empty |orig_inputs| matching the input count means remapping took
  // place above and must now be undone.
  if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
    // Remapping happened. Restore original inputs.
    for (int j = 0; j < node->inputs->size; ++j) {
      node->inputs->data[j] = orig_inputs[j];
      if (constant_dequant_nodes_.find(orig_inputs[j]) !=
          constant_dequant_nodes_.end()) {
        // If its a fp16 tensor, increment number of consumers of the
        // corresponding DEQUANTIZE.
        constant_dequant_consumers_[orig_inputs[j]] += 1;
      }
    }
  }
  return is_supported;
}
266 
RemapFp16InputTensors(const std::vector<int> & nodes) const267 void FP16GraphPartitionHelper::RemapFp16InputTensors(
268     const std::vector<int>& nodes) const {
269   for (int node_id : nodes) {
270     TfLiteNode* node;
271     TfLiteRegistration* registration;
272     TfLiteStatus status = context_->GetNodeAndRegistration(
273         context_, node_id, &node, &registration);
274     if (status != kTfLiteOk) {
275       TF_LITE_KERNEL_LOG(context_,
276                          "Couldn't get node and registration info for op: %d\n",
277                          node_id);
278     }
279     RemapFp16InputTensors(node, nullptr /* orig_inputs*/);
280   }
281 }
282 
RemapFp16InputTensors(TfLiteNode * node,std::vector<int> * orig_inputs) const283 void FP16GraphPartitionHelper::RemapFp16InputTensors(
284     TfLiteNode* node, std::vector<int>* orig_inputs) const {
285   TfLiteIntArray* inputs = node->inputs;
286   auto inputs_view = TfLiteIntArrayView(inputs);
287   // Prepopulate 'orig_inputs' first and clear it if there's no input from a
288   // dequant op.
289   if (orig_inputs) {
290     orig_inputs->clear();
291     orig_inputs->reserve(inputs->size);
292     for (auto tid : inputs_view) {
293       orig_inputs->push_back(tid);
294     }
295   }
296   // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
297   // order to test if it is supported.
298   bool is_remapped = false;
299   for (int j = 0; j < inputs->size; ++j) {
300     const int input_tid = inputs->data[j];
301     const auto it = constant_dequant_map_.find(input_tid);
302     if (it != constant_dequant_map_.end()) {
303       inputs->data[j] = it->second;
304       is_remapped = true;
305     }
306   }
307   if (!is_remapped && orig_inputs) orig_inputs->clear();
308 }
309 
310 }  // namespace delegates
311 }  // namespace tflite
312