/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
#define TENSORFLOW_LITE_DELEGATES_UTILS_H_

// Utility functions and classes for implementing delegates.

#include <functional>
#include <limits>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace delegates {

// Creates a new Read/Write tensor having the same shape as the original, but
// with a different type. Note that this may invalidate existing references to
// tensors.
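//
// Example usage (a minimal sketch; 'original_index' is assumed to be a tensor
// index already known to the calling delegate):
//
//   TfLiteTensor* fp32_tensor = nullptr;
//   int fp32_tensor_index = -1;
//   TF_LITE_ENSURE_STATUS(CreateNewTensorWithDifferentType(
//       context, original_index, kTfLiteFloat32, &fp32_tensor,
//       &fp32_tensor_index));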
TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
                                              const int original_tensor_index,
                                              TfLiteType new_type,
                                              TfLiteTensor** new_tensor,
                                              int* new_tensor_index);

using IsNodeSupportedFn =
    std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*,
                       std::string* unsupported_details)>;
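// For example, a delegate that only supports ADD and MUL could use a callback
// like the following (a minimal sketch; the builtin codes come from
// tensorflow/lite/builtin_ops.h, which is not included by this header):
//
//   IsNodeSupportedFn is_supported = [](TfLiteContext* context,
//                                       TfLiteNode* node,
//                                       TfLiteRegistration* registration,
//                                       std::string* unsupported_details) {
//     if (registration->builtin_code == kTfLiteBuiltinAdd ||
//         registration->builtin_code == kTfLiteBuiltinMul) {
//       return true;
//     }
//     if (unsupported_details) *unsupported_details = "op not supported";
//     return false;
//   };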

// A utility class to help with model graph partitioning.
// Note the class *needs* to be used in TfLiteDelegate::Prepare.
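//
// A typical flow inside a delegate's Prepare looks roughly like the sketch
// below ('is_supported', 'kernel_registration' and 'delegate' are assumptions
// supplied by the caller, not part of this header):
//
//   GraphPartitionHelper helper(context, is_supported);
//   std::set<std::string> unsupported_info;
//   TF_LITE_ENSURE_STATUS(helper.Partition(&unsupported_info));
//   std::vector<int> nodes = helper.GetNodesOfFirstNLargestPartitions();
//   TfLiteIntArray* nodes_to_delegate = ConvertVectorToTfLiteIntArray(nodes);
//   TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
//       context, kernel_registration, nodes_to_delegate, delegate);
//   TfLiteIntArrayFree(nodes_to_delegate);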
class GraphPartitionHelper {
 public:
  GraphPartitionHelper(TfLiteContext* context,
                       IsNodeSupportedFn is_node_supported_fn)
      : context_(context), is_node_supported_fn_(is_node_supported_fn) {}

  GraphPartitionHelper(TfLiteContext* context,
                       const std::vector<int>& supported_node_indices)
      : context_(context),
        num_total_nodes_(supported_node_indices.size()),
        supported_nodes_(
            ConvertVectorToTfLiteIntArray(supported_node_indices)) {}

  virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); }

  // Partition the graph into node subsets such that each subset could be
  // replaced with one delegate kernel (i.e. a kTfLiteBuiltinDelegate op).
  // If 'unsupported_nodes_info' is provided, it will be populated with
  // information about all different unsupported nodes.
  virtual TfLiteStatus Partition(std::set<std::string>* unsupported_nodes_info);

  // Returns the first 'n' largest partitions, or all of them if there are
  // fewer than 'n' partitions, where each returned partition has at least
  // (>=) 'min_nodes_per_partition' nodes. Note that partitions are ranked
  // according to the number of nodes they contain, and the returned
  // TfLiteDelegateParams objects are *owned* by the TfLite runtime.
  // TODO(b/156707497): remove this and use GetNodesOfFirstNLargestPartitions
  std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(
      int n = std::numeric_limits<int>::max(),
      int min_nodes_per_partition = 0) const;

  // Returns a list of node indices of all nodes from the first 'n' largest
  // partitions. If there are fewer partitions than 'n', all nodes will be
  // returned. Partitions are ranked according to the number of nodes they
  // contain.
  std::vector<int> GetNodesOfFirstNLargestPartitions(
      int n = std::numeric_limits<int>::max(),
      int min_nodes_per_partition = 0) {
    // The implementation is separated so it can be overridden while preserving
    // the default argument values.
    return GetNodesOfFirstNLargestPartitionsImpl(n, min_nodes_per_partition);
  }

  int num_total_nodes() const { return num_total_nodes_; }
  int num_partitions() const { return partitions_.size(); }

 protected:
  virtual bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
                               TfLiteRegistration* registration, int node_id,
                               std::string* unsupported_details) {
    return is_node_supported_fn_(context, node, registration,
                                 unsupported_details);
  }
  virtual std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
      int n, int min_nodes_per_partition);

  TfLiteContext* const context_ = nullptr;

  // Doesn't own the memory of each TfLiteDelegateParams object as it's
  // managed by the TfLite runtime itself. See
  // TfLiteContext::PreviewDelegatePartitioning for details.
  std::vector<TfLiteDelegateParams*> partitions_;

 private:
  // Generates the list of supported nodes (i.e. populates 'supported_nodes_')
  // by iterating over all nodes (i.e. those listed in the execution_plan
  // associated w/ 'context_').
  // If 'unsupported_nodes_info' is provided, it will be populated with
  // information about all different unsupported nodes.
  TfLiteStatus PrepareSupportedNodes(
      std::set<std::string>* unsupported_nodes_info = nullptr);

  // The total number of nodes passed in for partitioning (i.e. the size of the
  // execution_plan associated w/ 'context_').
  int num_total_nodes_ = 0;

  // Tells whether a node is supported and can thus be delegated.
  const IsNodeSupportedFn is_node_supported_fn_ = nullptr;

  // Contains an array of supported node indices.
  TfLiteIntArray* supported_nodes_ = nullptr;  // owns the memory
};

// Specialized partitioner for graphs that possibly contain fp16 tensors.
//
// From nodes that accept fp16 inputs, this delegates the following:
// 1. All nodes (except DEQUANTIZE) that are supported with constant fp16
// inputs by the delegate (in the TFLite graph, these nodes take in dequantized
// FP32 outputs).
// 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
// delegated partition. This is because TFLite's partitioning algorithm
// greedily puts all such nodes in the first partition.
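//
// For example (an illustrative sketch, not an exhaustive description), given
//
//   fp16_weights -> DEQUANTIZE -> fp32_weights -> CONV_2D -> output
//
// if the delegate reports CONV_2D as supported with constant fp16 inputs, both
// CONV_2D and the DEQUANTIZE node can be delegated, with CONV_2D remapped to
// read 'fp16_weights' directly (see RemapFp16InputTensors below).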
class FP16GraphPartitionHelper : public GraphPartitionHelper {
 public:
  FP16GraphPartitionHelper(TfLiteContext* context,
                           IsNodeSupportedFn is_node_supported_fn)
      : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}

 protected:
  // Specialized function to handle fp16 nodes.
  bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
                       TfLiteRegistration* registration, int node_id,
                       std::string* unsupported_details) override;

  // In addition to the base behavior, this remaps input tensors by replacing
  // FP16-to-FP32 dequantized tensors with their original FP16 tensors.
  std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
      int n, int min_nodes_per_partition) override;

 private:
  // Remaps fp32 inputs of the given node to their corresponding fp16
  // versions, if applicable. Can be summarized as:
  // fp16 -> DEQUANTIZE -> fp32 -> OP -> output
  // becomes
  // fp16 -> OP -> output
  void RemapFp16InputTensors(TfLiteNode* node,
                             std::vector<int>* orig_inputs) const;

  // Performs the above remapping for all nodes in the given list, without
  // tracking the original inputs.
  void RemapFp16InputTensors(const std::vector<int>& nodes) const;

  // ('dequantize' here refers to fp16 DEQUANTIZE.)
  // Mapping from a dequantize node's output tensor-id to its node-id.
  // TODO(b/156707497): Use absl hash_maps here.
  std::unordered_map<int, int> constant_dequant_nodes_;
  // Mapping from a DEQUANTIZE node's output (fp32) to its input (fp16).
  std::unordered_map<int, int> constant_dequant_map_;
  // Mapping from a DEQUANTIZE output tensor-id to its number of consumers.
  std::unordered_map<int, int> constant_dequant_consumers_;
};

}  // namespace delegates
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_UTILS_H_