1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/workgroup_selection.h"
17
18 #include <math.h>
19
20 #include <set>
21 #include <vector>
22
23 #include "tensorflow/lite/delegates/gpu/common/util.h"
24
25 namespace tflite {
26 namespace gpu {
27
28 namespace {
29
30 template <typename T>
AddCornerCases(const T & grid,int max_work_group_total_size,const T & max_work_group_sizes,WorkGroupSizeAlignment x_alignment,WorkGroupSizeAlignment y_alignment,WorkGroupSizeAlignment z_alignment,std::vector<T> * work_groups)31 void AddCornerCases(const T& grid, int max_work_group_total_size,
32 const T& max_work_group_sizes,
33 WorkGroupSizeAlignment x_alignment,
34 WorkGroupSizeAlignment y_alignment,
35 WorkGroupSizeAlignment z_alignment,
36 std::vector<T>* work_groups) {
37 for (int x = 1; x <= 4; ++x) {
38 for (int y = 1; y <= 4; ++y) {
39 for (int z = 1; z <= 4; ++z) {
40 int wg_x = DivideRoundUp(grid.x, x);
41 int wg_y = DivideRoundUp(grid.y, y);
42 int wg_z = DivideRoundUp(grid.z, z);
43 if (wg_x > max_work_group_sizes.x || wg_y > max_work_group_sizes.y ||
44 wg_z > max_work_group_sizes.z ||
45 wg_x * wg_y * wg_z > max_work_group_total_size) {
46 continue;
47 }
48 if (x_alignment == WorkGroupSizeAlignment::PRECISE &&
49 grid.x % wg_x != 0) {
50 continue;
51 }
52 if (y_alignment == WorkGroupSizeAlignment::PRECISE &&
53 grid.y % wg_y != 0) {
54 continue;
55 }
56 if (z_alignment == WorkGroupSizeAlignment::PRECISE &&
57 grid.z % wg_z != 0) {
58 continue;
59 }
60 work_groups->push_back({wg_x, wg_y, wg_z});
61 }
62 }
63 }
64
65 // this will add at least {1, 1, 1} always.
66 for (int x = 1; x <= 4; ++x) {
67 for (int y = 1; y <= 4; ++y) {
68 for (int z = 1; z <= 4; ++z) {
69 if (x > max_work_group_sizes.x || y > max_work_group_sizes.y ||
70 z > max_work_group_sizes.z ||
71 x * y * z > max_work_group_total_size) {
72 continue;
73 }
74 if (x_alignment == WorkGroupSizeAlignment::PRECISE && grid.x % x != 0) {
75 continue;
76 }
77 if (y_alignment == WorkGroupSizeAlignment::PRECISE && grid.y % y != 0) {
78 continue;
79 }
80 if (z_alignment == WorkGroupSizeAlignment::PRECISE && grid.z % z != 0) {
81 continue;
82 }
83 work_groups->push_back({x, y, z});
84 }
85 }
86 }
87 }
88
// Returns all positive divisors of `number`.
//
// Divisors are emitted in discovery order: for every i in [1, sqrt(number)]
// that divides `number`, i is appended, followed by its cofactor number / i
// (skipped when it equals i, so perfect squares yield no duplicates). The
// result is therefore not sorted.
//
// Returns an empty vector when `number` is not positive.
std::vector<int> GetDivisors(int number) {
  // std::sqrt of a negative value is NaN, and converting NaN to int is
  // undefined behavior (it can also turn into an enormous reserve() request),
  // so reject non-positive input up front.
  if (number <= 0) {
    return {};
  }
  const int max_divisor = static_cast<int>(std::sqrt(number));
  std::vector<int> divisors;
  // The number of divisors is not known in advance, so this is only a
  // heuristic.
  divisors.reserve(max_divisor / 3 + 1);
  for (int i = 1; i <= max_divisor; ++i) {
    if (number % i != 0) {
      continue;
    }
    divisors.push_back(i);
    const int cofactor = number / i;
    if (cofactor != i) {
      divisors.push_back(cofactor);
    }
  }
  return divisors;
}
105
// Collects, in ascending order and without duplicates, every divisor of every
// integer in the closed interval [number, number + range].
//
// For each candidate i up to sqrt(number + range), the multiples of i inside
// the interval are visited; both i and the matching cofactor (multiple / i)
// are recorded.
std::vector<int> GetDivisorsForRange(int number, int range) {
  const int upper_bound = number + range;
  const int root = static_cast<int>(std::sqrt(upper_bound));
  std::set<int> found;
  for (int candidate = 1; candidate <= root; ++candidate) {
    const int remainder = number % candidate;
    // Smallest multiple of `candidate` that is >= number.
    const int first_multiple = number + (candidate - remainder) % candidate;
    if (first_multiple > upper_bound) {
      // No multiple of `candidate` falls inside the interval.
      continue;
    }
    found.insert(candidate);
    for (int multiple = first_multiple; multiple <= upper_bound;
         multiple += candidate) {
      // The set silently ignores a cofactor equal to `candidate`.
      found.insert(multiple / candidate);
    }
  }
  return std::vector<int>(found.begin(), found.end());
}
126
127 } // namespace
128
GetPossibleSizes(int number,WorkGroupSizeAlignment z_alignment)129 std::vector<int> GetPossibleSizes(int number,
130 WorkGroupSizeAlignment z_alignment) {
131 if (z_alignment == WorkGroupSizeAlignment::PRECISE) {
132 // we will use for potential sizes, sizes that cover grid precisely
133 // work group size * k (k is integer) == grid_size
134 return GetDivisors(number);
135 } else {
136 // when we chose work group size we can use work group size that
137 // work group size * k (k is integer) != grid_size (slightly bigger)
138 // so in this heuristic we trying to find potential size, that satisfies
139 // to this : work group size * k (k is integer) <= grid_size + 5
140 // and this : work group size * k (k is integer) >= grid_size
141 return GetDivisorsForRange(number, 5);
142 }
143 }
144
145 template <typename T>
GenerateWorkGroupSizes(const T & grid,int min_work_group_total_size,int max_work_group_total_size,const T & max_work_group_sizes,WorkGroupSizeAlignment x_alignment,WorkGroupSizeAlignment y_alignment,WorkGroupSizeAlignment z_alignment)146 std::vector<T> GenerateWorkGroupSizes(
147 const T& grid, int min_work_group_total_size, int max_work_group_total_size,
148 const T& max_work_group_sizes, WorkGroupSizeAlignment x_alignment,
149 WorkGroupSizeAlignment y_alignment, WorkGroupSizeAlignment z_alignment) {
150 std::vector<T> work_groups;
151 work_groups.reserve(64);
152
153 std::vector<int> sizes_x = GetPossibleSizes(grid.x, x_alignment);
154 std::vector<int> sizes_y = GetPossibleSizes(grid.y, y_alignment);
155 std::vector<int> sizes_z = GetPossibleSizes(grid.z, z_alignment);
156
157 for (auto x : sizes_x) {
158 if (x > max_work_group_sizes.x) continue;
159 for (auto y : sizes_y) {
160 if (y > max_work_group_sizes.y) continue;
161 for (auto z : sizes_z) {
162 if (z > max_work_group_sizes.z) continue;
163 const int work_group_size = x * y * z;
164 if (work_group_size < min_work_group_total_size ||
165 work_group_size > max_work_group_total_size)
166 continue;
167 work_groups.push_back({x, y, z});
168 }
169 }
170 }
171
172 return work_groups;
173 }
174
// Explicit instantiations of GenerateWorkGroupSizes for int3 and uint3 (these
// instantiate the template defined in this file; they are not specializations).

template std::vector<int3> GenerateWorkGroupSizes(
    const int3& grid, int min_work_group_total_size,
    int max_work_group_total_size, const int3& max_work_group_sizes,
    WorkGroupSizeAlignment x_alignment, WorkGroupSizeAlignment y_alignment,
    WorkGroupSizeAlignment z_alignment);

template std::vector<uint3> GenerateWorkGroupSizes(
    const uint3& grid, int min_work_group_total_size,
    int max_work_group_total_size, const uint3& max_work_group_sizes,
    WorkGroupSizeAlignment x_alignment, WorkGroupSizeAlignment y_alignment,
    WorkGroupSizeAlignment z_alignment);
188
189 template <typename T>
GenerateWorkGroupSizesAlignedToGrid(const T & grid,const T & max_work_group_size,const int max_work_group_total_size,std::vector<T> * work_groups)190 void GenerateWorkGroupSizesAlignedToGrid(const T& grid,
191 const T& max_work_group_size,
192 const int max_work_group_total_size,
193 std::vector<T>* work_groups) {
194 auto alignment = WorkGroupSizeAlignment::PRECISE;
195 *work_groups = GenerateWorkGroupSizes<T>(
196 grid, /*min_work_group_total_size = */ 32, max_work_group_total_size,
197 max_work_group_size, alignment, alignment, alignment);
198 // If the grid parameter too small, method below cannot generate workgroups.
199 if (work_groups->empty()) {
200 AddCornerCases(grid, max_work_group_total_size, max_work_group_size,
201 alignment, alignment, alignment, work_groups);
202 }
203 }
204
// Explicit instantiations of GenerateWorkGroupSizesAlignedToGrid for int3 and
// uint3 (these instantiate the template defined in this file; they are not
// specializations).

template void GenerateWorkGroupSizesAlignedToGrid(
    const int3& grid, const int3& max_work_group_size,
    const int max_work_group_total_size, std::vector<int3>* work_groups);

template void GenerateWorkGroupSizesAlignedToGrid(
    const uint3& grid, const uint3& max_work_group_size,
    const int max_work_group_total_size, std::vector<uint3>* work_groups);
214
215 } // namespace gpu
216 } // namespace tflite
217