• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/memory/memory.h"
17 #include "tensorflow/lite/delegates/gpu/common/operations.h"
18 #include "tensorflow/lite/delegates/gpu/common/status.h"
19 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_buffer_1x1.h"
20 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h"
21 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h"
22 #include "tensorflow/lite/delegates/gpu/common/tasks/fully_connected.h"
23 
24 namespace tflite {
25 namespace gpu {
26 
SelectFullyConnectedGeneric(const FullyConnectedAttributes & attr,const GpuInfo & gpu_info,const OperationDef & op_def,int batch_size)27 std::unique_ptr<GPUOperation> SelectFullyConnectedGeneric(
28     const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
29     const OperationDef& op_def, int batch_size) {
30   if (op_def.IsBatchSupported()) {
31     BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o);
32     ConvPowerVR conv = CreateConvPowerVR(gpu_info, op_def, attr, &dst_shape);
33     return absl::make_unique<ConvPowerVR>(std::move(conv));
34   } else {
35     FullyConnected fc = CreateFullyConnected(gpu_info, op_def, attr);
36     return absl::make_unique<FullyConnected>(std::move(fc));
37   }
38 }
39 
SelectFullyConnectedAdreno(const FullyConnectedAttributes & attr,const GpuInfo & gpu_info,const OperationDef & op_def,int batch_size)40 std::unique_ptr<GPUOperation> SelectFullyConnectedAdreno(
41     const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
42     const OperationDef& op_def, int batch_size) {
43   if (op_def.IsBatchSupported()) {
44     BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o);
45     ConvPowerVR conv = CreateConvPowerVR(gpu_info, op_def, attr, &dst_shape);
46     return absl::make_unique<ConvPowerVR>(std::move(conv));
47   } else {
48     FullyConnected fc = CreateFullyConnected(gpu_info, op_def, attr);
49     return absl::make_unique<FullyConnected>(std::move(fc));
50   }
51 }
52 
SelectFullyConnectedPowerVR(const FullyConnectedAttributes & attr,const GpuInfo & gpu_info,const OperationDef & op_def,int batch_size)53 std::unique_ptr<GPUOperation> SelectFullyConnectedPowerVR(
54     const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
55     const OperationDef& op_def, int batch_size) {
56   if (op_def.IsBatchSupported()) {
57     ConvPowerVR conv = CreateConvPowerVR(gpu_info, op_def, attr);
58     return absl::make_unique<ConvPowerVR>(std::move(conv));
59   } else {
60     FullyConnected fc = CreateFullyConnected(gpu_info, op_def, attr);
61     return absl::make_unique<FullyConnected>(std::move(fc));
62   }
63 }
64 
SelectFullyConnectedMali(const FullyConnectedAttributes & attr,const GpuInfo & gpu_info,const OperationDef & op_def,int batch_size)65 std::unique_ptr<GPUOperation> SelectFullyConnectedMali(
66     const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
67     const OperationDef& op_def, int batch_size) {
68   if (op_def.IsBatchSupported()) {
69     if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
70       ConvBuffer1x1 conv = CreateConvBuffer1x1(gpu_info, op_def, attr);
71       return absl::make_unique<ConvBuffer1x1>(std::move(conv));
72     } else {
73       BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o);
74       ConvPowerVR conv =
75           CreateConvPowerVR(gpu_info, op_def, attr, &dst_shape);
76       return absl::make_unique<ConvPowerVR>(std::move(conv));
77     }
78   } else {
79     FullyConnected fc = CreateFullyConnected(gpu_info, op_def, attr);
80     return absl::make_unique<FullyConnected>(std::move(fc));
81   }
82 }
83 
SelectFullyConnected(const FullyConnectedAttributes & attr,const GpuInfo & gpu_info,const OperationDef & op_def,int batch_size)84 std::unique_ptr<GPUOperation> SelectFullyConnected(
85     const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
86     const OperationDef& op_def, int batch_size) {
87   if (gpu_info.IsApiMetal()) {
88     if (op_def.IsBatchSupported() && IsConvolutionMetalSupported(op_def)) {
89       BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o);
90       Convolution2DAttributes conv_attr;
91       conv_attr.padding.prepended = HW(0, 0);
92       conv_attr.padding.appended = HW(0, 0);
93       conv_attr.strides = HW(1, 1);
94       conv_attr.dilations = HW(1, 1);
95       conv_attr.weights = attr.weights;
96       conv_attr.bias = attr.bias;
97       ConvolutionMetal conv =
98           CreateConvolutionMetal(op_def, dst_shape, conv_attr, gpu_info);
99       return absl::make_unique<ConvolutionMetal>(std::move(conv));
100     } else {
101       FullyConnected fc = CreateFullyConnected(gpu_info, op_def, attr);
102       return absl::make_unique<FullyConnected>(std::move(fc));
103     }
104   } else if (gpu_info.IsAdreno()) {
105     return SelectFullyConnectedAdreno(attr, gpu_info, op_def, batch_size);
106   } else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() || gpu_info.IsNvidia() ||
107              gpu_info.IsIntel() || gpu_info.IsApple()) {
108     return SelectFullyConnectedPowerVR(attr, gpu_info, op_def, batch_size);
109   } else if (gpu_info.IsMali()) {
110     return SelectFullyConnectedMali(attr, gpu_info, op_def, batch_size);
111   } else {
112     return SelectFullyConnectedGeneric(attr, gpu_info, op_def, batch_size);
113   }
114 }
115 
116 }  // namespace gpu
117 }  // namespace tflite
118