1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_CONSTANTS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_CONSTANTS_H_
18
19 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
20 #include "tensorflow/lite/delegates/gpu/common/operations.h"
21 #include "tensorflow/lite/delegates/gpu/common/shape.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
24 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
25 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h"
27 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29
30 namespace tflite {
31 namespace gpu {
32
33 template <DataType S, typename T>
RearrangeWeightsForConvConstants(const tflite::gpu::Tensor<OHWI,S> & weights,absl::Span<T> dst)34 void RearrangeWeightsForConvConstants(
35 const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
36 const int dst_depth = DivideRoundUp(weights.shape.o, 4);
37 const int src_depth = DivideRoundUp(weights.shape.i, 4);
38 const int kernel_x = weights.shape.w;
39 const int kernel_y = weights.shape.h;
40
41 int counter = 0;
42 for (int s = 0; s < src_depth; ++s) {
43 for (int y = 0; y < kernel_y; ++y) {
44 for (int x = 0; x < kernel_x; ++x) {
45 for (int d = 0; d < dst_depth; ++d) {
46 const int channels_count = std::min(4, weights.shape.i - s * 4);
47 T filters[4];
48 for (int i = 0; i < 4; ++i) {
49 for (int j = 0; j < channels_count; ++j) {
50 const int s_ch = s * 4 + j;
51 const int d_ch = d * 4 + i;
52 if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
53 const int f_index =
54 weights.shape.LinearIndex({d_ch, y, x, s_ch});
55 filters[j][i] = weights.data[f_index];
56 } else {
57 filters[j][i] = 0.0f;
58 }
59 }
60 }
61 for (int i = 0; i < channels_count; ++i) {
62 dst[counter++] = filters[i];
63 }
64 }
65 }
66 }
67 }
68 }
69
70 template <DataType S, typename T>
RearrangeWeightsForConvConstantsDot(const tflite::gpu::Tensor<OHWI,S> & weights,absl::Span<T> dst)71 void RearrangeWeightsForConvConstantsDot(
72 const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
73 const int dst_depth = DivideRoundUp(weights.shape.o, 4);
74 const int src_depth = DivideRoundUp(weights.shape.i, 4);
75 const int kernel_x = weights.shape.w;
76 const int kernel_y = weights.shape.h;
77
78 int counter = 0;
79 for (int s = 0; s < src_depth; ++s) {
80 for (int y = 0; y < kernel_y; ++y) {
81 for (int x = 0; x < kernel_x; ++x) {
82 for (int d = 0; d < dst_depth; ++d) {
83 const int channels_count = std::min(4, weights.shape.o - d * 4);
84 T filters[4];
85 for (int j = 0; j < channels_count; ++j) {
86 for (int i = 0; i < 4; ++i) {
87 const int s_ch = s * 4 + i;
88 const int d_ch = d * 4 + j;
89 if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
90 const int f_index =
91 weights.shape.LinearIndex({d_ch, y, x, s_ch});
92 filters[j][i] = weights.data[f_index];
93 } else {
94 filters[j][i] = 0.0f;
95 }
96 }
97 }
98 for (int i = 0; i < channels_count; ++i) {
99 dst[counter++] = filters[i];
100 }
101 }
102 }
103 }
104 }
105 }
106
107 template <DataType T>
UploadWeightsForConvConstants(const tflite::gpu::Tensor<OHWI,T> & weights,CalculationsPrecision precision,bool use_dot_conv,GPUOperation * op)108 void UploadWeightsForConvConstants(const tflite::gpu::Tensor<OHWI, T>& weights,
109 CalculationsPrecision precision,
110 bool use_dot_conv, GPUOperation* op) {
111 const int src_depth = DivideRoundUp(weights.shape.i, 4);
112 const int dst_depth = DivideRoundUp(weights.shape.o, 4);
113 const int kernel_x = weights.shape.w;
114 const int kernel_y = weights.shape.h;
115
116 const bool f32_weights = precision == CalculationsPrecision::F32;
117 const int float_size = f32_weights ? 4 : 2;
118 const int aligned_ch_count = use_dot_conv ? weights.shape.o * src_depth * 4
119 : weights.shape.i * dst_depth * 4;
120 const int float_count = aligned_ch_count * kernel_x * kernel_y;
121
122 BufferDescriptor desc;
123 desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
124 desc.element_size = 4;
125 desc.memory_type = MemoryType::CONSTANT;
126 desc.size = float_size * float_count;
127 desc.data.resize(desc.size);
128
129 if (f32_weights) {
130 float4* ptr = reinterpret_cast<float4*>(desc.data.data());
131 if (use_dot_conv) {
132 RearrangeWeightsForConvConstantsDot(weights,
133 absl::MakeSpan(ptr, float_count / 4));
134 } else {
135 RearrangeWeightsForConvConstants(weights,
136 absl::MakeSpan(ptr, float_count / 4));
137 }
138 } else {
139 half4* ptr = reinterpret_cast<half4*>(desc.data.data());
140 if (use_dot_conv) {
141 RearrangeWeightsForConvConstantsDot(weights,
142 absl::MakeSpan(ptr, float_count / 4));
143 } else {
144 RearrangeWeightsForConvConstants(weights,
145 absl::MakeSpan(ptr, float_count / 4));
146 }
147 }
148
149 op->args_.AddObject("weights",
150 absl::make_unique<BufferDescriptor>(std::move(desc)));
151 }
152
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name and signature, this presumably reports
// whether the constants-based convolution can be used for the given GPU,
// operation definition and convolution attributes -- confirm against the
// implementation.
153 bool IsConvConstantsSupported(const GpuInfo& gpu_info,
154 const OperationDef& definition,
155 const Convolution2DAttributes& attr);
156
// Declaration only; the definition is not part of this header. Returns a
// GPUOperation by value.
// NOTE(review): presumably builds the constants-based convolution for the
// given GPU, operation definition and attributes, and callers are probably
// expected to check IsConvConstantsSupported first -- confirm against the
// implementation.
157 GPUOperation CreateConvConstants(const GpuInfo& gpu_info,
158 const OperationDef& definition,
159 const Convolution2DAttributes& attr);
160
161 } // namespace gpu
162 } // namespace tflite
163
164 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_CONSTANTS_H_
165