1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_ 17 #define TENSORFLOW_LITE_DELEGATES_UTILS_H_ 18 19 // Utility functions and classes for implementing delegates. 20 21 #include <functional> 22 #include <limits> 23 #include <set> 24 #include <string> 25 #include <unordered_map> 26 #include <utility> 27 #include <vector> 28 29 #include "tensorflow/lite/c/common.h" 30 #include "tensorflow/lite/util.h" 31 32 namespace tflite { 33 namespace delegates { 34 35 // Creates a new Read/Write tensor having the same shape as the original, but 36 // with a different type. Note that this might void existing references to 37 // tensors. 38 TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context, 39 const int original_tensor_index, 40 TfLiteType new_type, 41 TfLiteTensor** new_tensor, 42 int* new_tensor_index); 43 44 using IsNodeSupportedFn = 45 std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*, 46 std::string* unsupported_details)>; 47 48 // A utility class to help model graph parition. 49 // Note the class *needs* to be used in TfLiteDelegate::Prepare. 50 class GraphPartitionHelper { 51 public: GraphPartitionHelper(TfLiteContext * context,IsNodeSupportedFn is_node_supported_fn)52 GraphPartitionHelper(TfLiteContext* context, 53 IsNodeSupportedFn is_node_supported_fn) 54 : context_(context), is_node_supported_fn_(is_node_supported_fn) {} 55 GraphPartitionHelper(TfLiteContext * context,const std::vector<int> & supported_node_indices)56 GraphPartitionHelper(TfLiteContext* context, 57 const std::vector<int>& supported_node_indices) 58 : context_(context), 59 num_total_nodes_(supported_node_indices.size()), 60 supported_nodes_( 61 ConvertVectorToTfLiteIntArray(supported_node_indices)) {} 62 ~GraphPartitionHelper()63 virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); } 64 65 // Partition the graph into node subsets such that each subset could be 66 // replaced with one delegate kernel (i.e. a kTfLiteBuiltinDelegate op). 67 // If 'unsupported_nodes_info' is provided, it will be populated with 68 // information about all different unsupported nodes. 69 virtual TfLiteStatus Partition(std::set<std::string>* unsupported_nodes_info); 70 71 // Returns the first n largest partitions or all if #partitions is less than 72 // 'n' and each parition has at least (>=) 'min_nodes_per_partition' nodes. 73 // Note that partitions are ranked according to the number of nodes that 74 // a partition has, and the returned TfLiteDelegateParams objects are *owned* 75 // by the TfLite runtime. 76 // TODO(b/156707497): remove this and use GetNodesOfFirstNLargestPartitions 77 std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions( 78 int n = std::numeric_limits<int>::max(), 79 int min_nodes_per_partition = 0) const; 80 81 // Returns a list of node indices of all nodes from the first n largest 82 // partitions. If there are fewer paritions than n, all nodes will be 83 // returned. The partition is ranked according to the number of nodes. 84 std::vector<int> GetNodesOfFirstNLargestPartitions( 85 int n = std::numeric_limits<int>::max(), 86 int min_nodes_per_partition = 0) { 87 // Separated implementation that can be overrided, to preserve default value 88 return GetNodesOfFirstNLargestPartitionsImpl(n, min_nodes_per_partition); 89 } 90 num_total_nodes()91 int num_total_nodes() const { return num_total_nodes_; } num_partitions()92 int num_partitions() const { return partitions_.size(); } 93 94 protected: IsNodeSupported(TfLiteContext * context,TfLiteNode * node,TfLiteRegistration * registration,int node_id,std::string * unsupported_details)95 virtual bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, 96 TfLiteRegistration* registration, int node_id, 97 std::string* unsupported_details) { 98 return is_node_supported_fn_(context, node, registration, 99 unsupported_details); 100 } 101 virtual std::vector<int> GetNodesOfFirstNLargestPartitionsImpl( 102 int n, int min_nodes_per_partition); 103 104 TfLiteContext* const context_ = nullptr; 105 106 // Doesn't own the memory of each TfLiteDelegateParams object as it's 107 // managed by the TfLite runtime itself. See 108 // TfLiteContext::PreviewDelegatePartitioning for details. 109 std::vector<TfLiteDelegateParams*> partitions_; 110 111 private: 112 // Generate a list of supported nodes (i.e. populating 'supported_nodes_') by 113 // iterating over all nodes (i,e. those listed in the execution_plan 114 // associated w/ 'context_'). 115 // If 'unsupported_nodes_info' is provided, it will be populated with 116 // information about all different unsupported nodes. 117 TfLiteStatus PrepareSupportedNodes( 118 std::set<std::string>* unsupported_nodes_info = nullptr); 119 120 // The number of total nodes passed in for partitioning (i.e. the 121 // execution_plan size associated w/ 'context_') 122 int num_total_nodes_ = 0; 123 124 // Tells if a node is supported as it could be delegated. 125 const IsNodeSupportedFn is_node_supported_fn_ = nullptr; 126 127 // Contains an array of supported node indices. 128 TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory 129 }; 130 131 // Specialized partitioner for graphs that possibly contain fp16 tensors. 132 // 133 // From nodes that accept fp16 inputs, this delegates the following: 134 // 1. All nodes (except DEQUANTIZE) that are supported with constant fp16 inputs 135 // by the delegate (in the TFLite graph, these nodes take in dequantized FP32 136 // outputs). 137 // 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first* 138 // delegated partition. This is because TFLite's partitioning algorithm 139 // greedily puts all such nodes in the first partition. 140 class FP16GraphPartitionHelper : public GraphPartitionHelper { 141 public: FP16GraphPartitionHelper(TfLiteContext * context,IsNodeSupportedFn is_node_supported_fn)142 FP16GraphPartitionHelper(TfLiteContext* context, 143 IsNodeSupportedFn is_node_supported_fn) 144 : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} 145 146 protected: 147 // Specialized function to handle fp16 nodes. 148 bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, 149 TfLiteRegistration* registration, int node_id, 150 std::string* unsupported_details) override; 151 152 // This will remap input tensors by removing FP16 to FP32 dequantized tensors. 153 std::vector<int> GetNodesOfFirstNLargestPartitionsImpl( 154 int n, int min_nodes_per_partition) override; 155 156 private: 157 // This remaps fp32 inputs of the given node to their corresponding fp16 158 // version, if applicable. Can be summarized as: 159 // fp16 -> DEQUANTIZE -> fp32 -> OP -> output 160 // becomes 161 // fp16 -> OP -> output 162 void RemapFp16InputTensors(TfLiteNode* node, 163 std::vector<int>* orig_inputs) const; 164 165 // Performs the above remapping for all nodes in the given list, without 166 // tracking the original inputs. 167 void RemapFp16InputTensors(const std::vector<int>& nodes) const; 168 169 // ('dequantize' here refers to fp16 DEQUANTIZE) 170 // Mapping of dequantize nodes' output tensor-id to its node id. 171 // TODO(b/156707497): Use absl hash_maps here. 172 std::unordered_map<int, int> constant_dequant_nodes_; 173 // Mapping of DEQUANTIZE node's output (fp32) to its input (fp16). 174 std::unordered_map<int, int> constant_dequant_map_; 175 // mapping of DEQUANTIZE output tensor-id to its number of consumers. 176 std::unordered_map<int, int> constant_dequant_consumers_; 177 }; 178 179 } // namespace delegates 180 } // namespace tflite 181 182 #endif // TENSORFLOW_LITE_DELEGATES_UTILS_H_ 183