/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/utils.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegates {

TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
                                              const int original_tensor_index,
                                              TfLiteType new_type,
                                              TfLiteTensor** new_tensor,
                                              int* new_tensor_index) {
  TF_LITE_ENSURE_STATUS(context->AddTensors(context, 1, new_tensor_index));
  const TfLiteTensor& original_tensor = context->tensors[original_tensor_index];
  *new_tensor = &context->tensors[*new_tensor_index];
  (*new_tensor)->type = new_type;
  (*new_tensor)->allocation_type = kTfLiteArenaRw;
  const auto* original_dims = original_tensor.dims;
  TfLiteIntArray* dims = TfLiteIntArrayCreate(original_dims->size);
  for (int i = 0; i < original_dims->size; ++i) {
    dims->data[i] = original_dims->data[i];
  }
  if (context->ResizeTensor(context, *new_tensor, dims) != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context, "Could not resize new delegate tensor");
    return kTfLiteError;
  }
  return kTfLiteOk;
}
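
// Illustrative usage (a hypothetical sketch, not called anywhere in this
// file): a delegate kernel's Prepare step could shadow an fp16 input with an
// fp32 scratch tensor via this helper. `fp16_input_index` is an assumed,
// valid tensor index in `context`.
//
//   TfLiteTensor* fp32_tensor = nullptr;
//   int fp32_tensor_index = -1;
//   TF_LITE_ENSURE_STATUS(CreateNewTensorWithDifferentType(
//       context, fp16_input_index, kTfLiteFloat32, &fp32_tensor,
//       &fp32_tensor_index));
//   // fp32_tensor now matches the original tensor's shape, lives in the
//   // arena, and can receive dequantized data during Eval.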

TfLiteStatus GraphPartitionHelper::Partition(
    std::set<std::string>* unsupported_nodes_info) {
  const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
  if (prepare_status != kTfLiteOk) return prepare_status;

  // These are locals, not members, so they don't take the member-style
  // trailing underscore.
  TfLiteDelegateParams* partition_params_array = nullptr;
  int num_partitions = 0;
  if (context_->PreviewDelegatePartitioning(context_, supported_nodes_,
                                            &partition_params_array,
                                            &num_partitions) != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context_, "Unable to preview delegate partition.\n");
    return kTfLiteError;
  }

  for (int i = 0; i < num_partitions; ++i) {
    partitions_.push_back(partition_params_array + i);
  }

  return kTfLiteOk;
}

std::vector<TfLiteDelegateParams*>
GraphPartitionHelper::GetFirstNLargestPartitions(
    int n, int min_nodes_per_partition) const {
  // In general, the number of partitions in a delegate is unlikely to be high
  // enough to cause latency issues. Since this is also one-time work, we
  // simply sort all partitions by size unconditionally.
  std::vector<TfLiteDelegateParams*> sorted_partitions(partitions_);
  std::sort(sorted_partitions.begin(), sorted_partitions.end(),
            [](TfLiteDelegateParams* left, TfLiteDelegateParams* right) {
              // Reverse sort: largest partitions first.
              return left->nodes_to_replace->size >
                     right->nodes_to_replace->size;
            });

  std::vector<TfLiteDelegateParams*> results;
  auto p_it = sorted_partitions.begin();
  const int total = sorted_partitions.size();
  for (int i = 0; i < std::min(total, n); ++i, ++p_it) {
    auto* p = (*p_it);
    if (p->nodes_to_replace->size < min_nodes_per_partition) {
      break;
    }
    results.push_back(p);
  }
  return results;
}
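
// Illustrative flow (a hypothetical sketch; `context` and
// `is_node_supported_fn` are assumed to exist at the call site): a delegate
// typically partitions the graph and then keeps only the largest supported
// partition(s):
//
//   GraphPartitionHelper helper(context, is_node_supported_fn);
//   if (helper.Partition(/*unsupported_nodes_info=*/nullptr) != kTfLiteOk)
//     return kTfLiteError;
//   auto largest = helper.GetFirstNLargestPartitions(
//       /*n=*/1, /*min_nodes_per_partition=*/2);
//   // `largest` holds at most one TfLiteDelegateParams*: the partition with
//   // the most nodes, provided it contains at least 2 nodes.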

std::vector<int> GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
    int n, int min_nodes_per_partition) {
  auto first_n_partitions =
      GetFirstNLargestPartitions(n, min_nodes_per_partition);
  std::vector<int> ops_to_replace;
  for (const auto p : first_n_partitions) {
    auto nodes = p->nodes_to_replace;
    ops_to_replace.insert(ops_to_replace.end(), nodes->data,
                          nodes->data + nodes->size);
  }
  return ops_to_replace;
}

TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
    std::set<std::string>* unsupported_nodes_info) {
  if (!is_node_supported_fn_) return kTfLiteOk;

  TfLiteIntArray* execution_plan = nullptr;
  auto status = context_->GetExecutionPlan(context_, &execution_plan);
  if (status != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context_, "Unable to get graph execution plan.\n");
    return status;
  }
  // context->GetExecutionPlan invalidates memory obtained from previous calls,
  // which is dangerous if a delegate's IsNodeSupportedFn uses it anywhere.
  // So we store a copy to ensure validity.
  num_total_nodes_ = execution_plan->size;
  original_execution_plan_ = TfLiteIntArrayCreate(execution_plan->size);
  std::memcpy(original_execution_plan_->data, execution_plan->data,
              num_total_nodes_ * sizeof(int32_t));

  supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_);
  supported_nodes_->size = 0;
  for (int node_id : TfLiteIntArrayView(original_execution_plan_)) {
    TfLiteNode* node;
    TfLiteRegistration* registration;

    status = context_->GetNodeAndRegistration(context_, node_id, &node,
                                              &registration);
    if (status != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context_,
                         "Couldn't get node and registration info for op: %d\n",
                         node_id);
      supported_nodes_->size = 0;
      return status;
    }

    std::string unsupported_details;
    if (IsNodeSupported(context_, node, registration, node_id,
                        &unsupported_details)) {
      supported_nodes_->data[supported_nodes_->size++] = node_id;
    } else if (unsupported_nodes_info) {
      std::string node_info = GetOpNameByRegistration(*registration);
      node_info.append(": ");
      node_info.append(unsupported_details);
      unsupported_nodes_info->insert(node_info);
    }
  }

  num_supported_nodes_ = supported_nodes_->size;
  return kTfLiteOk;
}
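
// For illustration (a hypothetical sketch, assuming the IsNodeSupportedFn
// alias declared in utils.h; the whitelist is an assumption, not from this
// file), a support check handed to this helper could accept only ADD nodes
// and report everything else:
//
//   IsNodeSupportedFn is_node_supported_fn =
//       [](TfLiteContext* context, TfLiteNode* node,
//          TfLiteRegistration* registration,
//          std::string* unsupported_details) -> bool {
//     if (registration->builtin_code == kTfLiteBuiltinAdd) return true;
//     if (unsupported_details) *unsupported_details = "only ADD is supported";
//     return false;
//   };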

std::vector<int>
FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
    int n, int min_nodes_per_partition) {
  std::vector<int> ops_to_replace;

  if (num_supported_nodes() + constant_dequant_nodes_.size() ==
      num_total_nodes()) {
    // Scenario 1: Full Delegation.
    // We delegate all nodes in this case to avoid unnecessary partitions due
    // to FP16 DEQUANT nodes. This is safe to do since no non-delegated op
    // needs the output of such a DEQUANT.
    for (int node_id : TfLiteIntArrayView(original_execution_plan_)) {
      ops_to_replace.push_back(node_id);
    }
  } else {
    // Scenario 2: Partial Delegation.
    // In this case, we just select the top 'n' applicable node subsets to
    // delegate, devoid of any FP16 DEQUANT ops. Handling the latter is tricky
    // in partial delegation cases and causes edge cases if non-delegated
    // nodes consume their output. So we keep all of them on the CPU.
    auto first_n_partitions =
        GetFirstNLargestPartitions(n, min_nodes_per_partition);
    if (first_n_partitions.empty()) return ops_to_replace;
    for (size_t i = 0; i < first_n_partitions.size(); ++i) {
      auto nodes = first_n_partitions[i]->nodes_to_replace;
      ops_to_replace.insert(ops_to_replace.end(), nodes->data,
                            nodes->data + nodes->size);
    }
  }

  // Modify the inputs of relevant ops that support fp16 constants.
  RemapFp16InputTensors(ops_to_replace);
  return ops_to_replace;
}
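
// Worked example (hypothetical numbers, not from this file): suppose the
// graph has 10 nodes, 2 of which are DEQUANTIZE ops on constant fp16 tensors
// and 8 of which pass the delegate's support check. Then
// num_supported_nodes() + constant_dequant_nodes_.size() == 8 + 2 == 10 ==
// num_total_nodes(), so Scenario 1 applies and all 10 nodes are delegated.
// If instead only 7 non-DEQUANT nodes were supported, the sum would be 9 < 10
// and Scenario 2 would delegate just the largest supported partitions,
// leaving the DEQUANTIZE ops on the CPU.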

bool FP16GraphPartitionHelper::IsNodeSupported(
    TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
    int node_id, std::string* unsupported_details) {
  if (registration->builtin_code == kTfLiteBuiltinDequantize) {
    auto& dequantize_input = context_->tensors[node->inputs->data[0]];
    if (dequantize_input.type == kTfLiteFloat16 &&
        IsConstantTensor(&dequantize_input)) {
      // Update mappings if this node is an fp16 DEQUANTIZE node that
      // works on a **constant** input tensor.
      // If the input is not a constant, the remapping we do here would
      // cause bugs due to preceding ops such as DENSIFY.
      constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
      constant_dequant_nodes_[node->outputs->data[0]] = node_id;
      // We do not accept these ops right now.
      // This is done to support use-cases where a DEQUANTIZE output might be
      // consumed by a CPU op.
      return false;
    }
  }

  // To check if a (possibly) FP16 node is supported, we temporarily point the
  // node's inputs to the original fp16 tensors. This 'mutated' node is then
  // passed to the base IsNodeSupported function for checking. After the
  // check, we restore the node's original inputs, so the TFLite graph remains
  // unchanged.
  std::vector<int> orig_inputs;
  if (!constant_dequant_nodes_.empty()) {
    RemapFp16InputTensors(node, &orig_inputs);
  }

  const auto is_supported = GraphPartitionHelper::IsNodeSupported(
      context, node, registration, node_id, unsupported_details);

  if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
    // Remapping happened. Restore original inputs.
    for (int j = 0; j < node->inputs->size; ++j) {
      node->inputs->data[j] = orig_inputs[j];
    }
  }
  return is_supported;
}

void FP16GraphPartitionHelper::RemapFp16InputTensors(
    const std::vector<int>& nodes) const {
  for (int node_id : nodes) {
    TfLiteNode* node;
    TfLiteRegistration* registration;
    TfLiteStatus status = context_->GetNodeAndRegistration(
        context_, node_id, &node, &registration);
    if (status != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context_,
                         "Couldn't get node and registration info for op: %d\n",
                         node_id);
      // Skip this node: `node` is not valid if the lookup failed.
      continue;
    }
    RemapFp16InputTensors(node, /*orig_inputs=*/nullptr);
  }
}

void FP16GraphPartitionHelper::RemapFp16InputTensors(
    TfLiteNode* node, std::vector<int>* orig_inputs) const {
  TfLiteIntArray* inputs = node->inputs;
  auto inputs_view = TfLiteIntArrayView(inputs);
  // Prepopulate 'orig_inputs' first, and clear it later if no input comes
  // from a constant dequant op.
  if (orig_inputs) {
    orig_inputs->clear();
    orig_inputs->reserve(inputs->size);
    for (auto tid : inputs_view) {
      orig_inputs->push_back(tid);
    }
  }
  // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
  // order to test if it is supported.
  bool is_remapped = false;
  for (int j = 0; j < inputs->size; ++j) {
    const int input_tid = inputs->data[j];
    const auto it = constant_dequant_map_.find(input_tid);
    if (it != constant_dequant_map_.end()) {
      inputs->data[j] = it->second;
      is_remapped = true;
    }
  }
  if (!is_remapped && orig_inputs) orig_inputs->clear();
}
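
// Concrete illustration (hypothetical tensor/node ids): if node 3 is
// DEQUANTIZE(constant fp16 tensor 2) -> tensor 5, then constant_dequant_map_
// holds {5: 2}. For a consumer node with inputs [5, 7], this remap rewrites
// them to [2, 7] and records orig_inputs = [5, 7] so that IsNodeSupported()
// can restore the graph afterwards.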

}  // namespace delegates
}  // namespace tflite