• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
17 
18 #include "absl/types/any.h"
19 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
20 #include "tensorflow/lite/delegates/gpu/common/operations.h"
21 #include "tensorflow/lite/delegates/gpu/common/shape.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
24 #include "tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h"
25 #include "tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h"
26 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
27 
28 namespace tflite {
29 namespace gpu {
30 namespace {
TryDepthwiseConvPlus1x1Conv(CalculationsPrecision precision,const GraphFloat32 & graph,NodeId first_node_id,const std::map<ValueId,TensorDescriptor> & tensor_descriptors,std::set<NodeId> * consumed_nodes,GPUOperationsSubgraph * gpu_subgraph)31 absl::Status TryDepthwiseConvPlus1x1Conv(
32     CalculationsPrecision precision, const GraphFloat32& graph,
33     NodeId first_node_id,
34     const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
35     std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
36   auto* dw_node = graph.GetNode(first_node_id);
37   if (dw_node == nullptr) {
38     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
39   }
40   if (OperationTypeFromString(dw_node->operation.type) !=
41       OperationType::DEPTHWISE_CONVOLUTION) {
42     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
43   }
44   auto dw_inputs = graph.FindInputs(dw_node->id);
45   if (dw_inputs.size() != 1) {
46     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
47   }
48   auto dw_outputs = graph.FindOutputs(dw_node->id);
49   auto consumers = graph.FindConsumers(dw_outputs[0]->id);
50   if (consumers.size() != 1) {
51     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
52   }
53   auto* conv_node = consumers[0];
54   if (conv_node == nullptr) {
55     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
56   }
57   if (consumed_nodes->find(conv_node->id) != consumed_nodes->end()) {
58     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
59   }
60   if (OperationTypeFromString(conv_node->operation.type) !=
61       OperationType::CONVOLUTION_2D) {
62     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
63   }
64   if (graph.FindInputs(conv_node->id).size() != 1) {
65     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
66   }
67   auto dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
68       dw_node->operation.attributes);
69   auto conv_attr =
70       absl::any_cast<Convolution2DAttributes>(conv_node->operation.attributes);
71   auto conv_outputs = graph.FindOutputs(conv_node->id);
72   OperationDef op_def;
73   op_def.precision = precision;
74   auto it = tensor_descriptors.find(dw_inputs[0]->id);
75   if (it != tensor_descriptors.end()) {
76     op_def.src_tensors.push_back(it->second);
77   }
78   it = tensor_descriptors.find(conv_outputs[0]->id);
79   if (it != tensor_descriptors.end()) {
80     op_def.dst_tensors.push_back(it->second);
81   }
82   if (!IsDepthwiseConvPlus1x1ConvSupported(op_def, dw_attr, conv_attr)) {
83     return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
84   }
85   std::unique_ptr<GPUOperation>* gpu_op =
86       InitSingleOpSubgraph(dw_inputs, conv_outputs, gpu_subgraph);
87   auto operation = CreateDepthwiseConvPlus1x1Conv(op_def, dw_attr, conv_attr);
88   *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
89   consumed_nodes->insert(dw_node->id);
90   consumed_nodes->insert(conv_node->id);
91   return absl::OkStatus();
92 }
93 
94 // fully connected + fully connected + add
TryFCFCAdd(const GpuInfo & gpu_info,CalculationsPrecision precision,const GraphFloat32 & graph,NodeId first_node_id,const std::map<ValueId,TensorDescriptor> & tensor_descriptors,std::set<NodeId> * consumed_nodes,GPUOperationsSubgraph * gpu_subgraph)95 absl::Status TryFCFCAdd(
96     const GpuInfo& gpu_info, CalculationsPrecision precision,
97     const GraphFloat32& graph, NodeId first_node_id,
98     const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
99     std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
100   auto* fc0_node = graph.GetNode(first_node_id);
101   if (fc0_node == nullptr) {
102     return absl::NotFoundError("FCFCAdd not suitable.");
103   }
104   if (OperationTypeFromString(fc0_node->operation.type) !=
105       OperationType::FULLY_CONNECTED) {
106     return absl::NotFoundError("FCFCAdd not suitable.");
107   }
108   auto fc0_inputs = graph.FindInputs(fc0_node->id);
109   if (fc0_inputs.size() != 1) {
110     return absl::NotFoundError("FCFCAdd not suitable.");
111   }
112   auto fc0_output_id = graph.FindOutputs(fc0_node->id)[0]->id;
113   auto consumers = graph.FindConsumers(fc0_output_id);
114   if (consumers.size() != 1) {
115     return absl::NotFoundError("FCFCAdd not suitable.");
116   }
117   auto* add_node = consumers[0];
118   if (add_node == nullptr) {
119     return absl::NotFoundError("FCFCAdd not suitable.");
120   }
121   if (consumed_nodes->find(add_node->id) != consumed_nodes->end()) {
122     return absl::NotFoundError("FCFCAdd not suitable.");
123   }
124   if (OperationTypeFromString(add_node->operation.type) != OperationType::ADD) {
125     return absl::NotFoundError("FCFCAdd not suitable.");
126   }
127   auto add_inputs = graph.FindInputs(add_node->id);
128   if (add_inputs.size() != 2) {
129     return absl::NotFoundError("FCFCAdd not suitable.");
130   }
131   auto fc1_output_id = add_inputs[0]->id + add_inputs[1]->id - fc0_output_id;
132   auto* fc1_node = graph.FindProducer(fc1_output_id);
133   if (fc1_node == nullptr) {
134     return absl::NotFoundError("FCFCAdd not suitable.");
135   }
136   if (OperationTypeFromString(fc1_node->operation.type) !=
137       OperationType::FULLY_CONNECTED) {
138     return absl::NotFoundError("FCFCAdd not suitable.");
139   }
140   if (consumed_nodes->find(fc1_node->id) != consumed_nodes->end()) {
141     return absl::NotFoundError("FCFCAdd not suitable.");
142   }
143   auto fc1_inputs = graph.FindInputs(fc1_node->id);
144   if (fc1_inputs.size() != 1) {
145     return absl::NotFoundError("FCFCAdd not suitable.");
146   }
147   auto fc0_attr =
148       absl::any_cast<FullyConnectedAttributes>(fc0_node->operation.attributes);
149   auto fc1_attr =
150       absl::any_cast<FullyConnectedAttributes>(fc1_node->operation.attributes);
151   if (fc0_attr.weights.shape.o != fc1_attr.weights.shape.o) {
152     return absl::NotFoundError("FCFCAdd not suitable.");
153   }
154   auto add_outputs = graph.FindOutputs(add_node->id);
155 
156   OperationDef op_def;
157   op_def.precision = precision;
158   auto it = tensor_descriptors.find(fc0_inputs[0]->id);
159   if (it != tensor_descriptors.end()) {
160     op_def.src_tensors.push_back(it->second);
161   }
162   it = tensor_descriptors.find(fc1_inputs[0]->id);
163   if (it != tensor_descriptors.end()) {
164     op_def.src_tensors.push_back(it->second);
165   }
166   it = tensor_descriptors.find(add_outputs[0]->id);
167   if (it != tensor_descriptors.end()) {
168     op_def.dst_tensors.push_back(it->second);
169   }
170 
171   for (int i = 0; i < fc1_inputs.size(); ++i) {
172     fc0_inputs.push_back(fc1_inputs[i]);
173   }
174   std::unique_ptr<GPUOperation>* gpu_op =
175       InitSingleOpSubgraph(fc0_inputs, add_outputs, gpu_subgraph);
176   FCFCAdd fc = CreateFCFCAdd(gpu_info, op_def, fc0_attr, fc1_attr);
177   *gpu_op = absl::make_unique<FCFCAdd>(std::move(fc));
178   consumed_nodes->insert(fc0_node->id);
179   consumed_nodes->insert(fc1_node->id);
180   consumed_nodes->insert(add_node->id);
181   return absl::OkStatus();
182 }
183 }  // namespace
184 
GPUSubgraphFromGraph(const GpuInfo & gpu_info,CalculationsPrecision precision,const GraphFloat32 & graph,NodeId first_node_id,const std::map<ValueId,TensorDescriptor> & tensor_descriptors,std::set<NodeId> * consumed_nodes,GPUOperationsSubgraph * gpu_subgraph,std::string * name)185 absl::Status GPUSubgraphFromGraph(
186     const GpuInfo& gpu_info, CalculationsPrecision precision,
187     const GraphFloat32& graph, NodeId first_node_id,
188     const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
189     std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph,
190     std::string* name) {
191   if ((gpu_info.IsAdreno() || gpu_info.IsNvidia()) &&
192       TryDepthwiseConvPlus1x1Conv(precision, graph, first_node_id,
193                                   tensor_descriptors, consumed_nodes,
194                                   gpu_subgraph)
195           .ok()) {
196     *name = "depthwise_conv_plus_1x1_conv";
197     return absl::OkStatus();
198   }
199   if ((gpu_info.IsIntel() || gpu_info.IsNvidia()) &&
200       TryFCFCAdd(gpu_info, precision, graph, first_node_id, tensor_descriptors,
201                  consumed_nodes, gpu_subgraph)
202           .ok()) {
203     *name = "fully_connected_x2_and_add";
204     return absl::OkStatus();
205   }
206   return absl::NotFoundError("No special combination.");
207 }
208 
209 }  // namespace gpu
210 }  // namespace tflite
211