1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/zero_init_workgroup_memory.h"
16
17 #include <algorithm>
18 #include <map>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22
23 #include "src/ast/workgroup_decoration.h"
24 #include "src/program_builder.h"
25 #include "src/sem/atomic_type.h"
26 #include "src/sem/function.h"
27 #include "src/sem/variable.h"
28 #include "src/utils/map.h"
29 #include "src/utils/unique_vector.h"
30
31 TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory);
32
33 namespace tint {
34 namespace transform {
35
36 /// PIMPL state for the ZeroInitWorkgroupMemory transform
37 struct ZeroInitWorkgroupMemory::State {
38 /// The clone context
39 CloneContext& ctx;
40
41 /// An alias to *ctx.dst
42 ProgramBuilder& b = *ctx.dst;
43
44 /// The constant size of the workgroup. If 0, then #workgroup_size_expr should
45 /// be used instead.
46 uint32_t workgroup_size_const = 0;
47 /// The size of the workgroup as an expression generator. Use if
48 /// #workgroup_size_const is 0.
49 std::function<const ast::Expression*()> workgroup_size_expr;
50
51 /// ArrayIndex represents a function on the local invocation index, of
52 /// the form: `array_index = (local_invocation_index % modulo) / division`
53 struct ArrayIndex {
54 /// The RHS of the modulus part of the expression
55 uint32_t modulo = 1;
56 /// The RHS of the division part of the expression
57 uint32_t division = 1;
58
59 /// Equality operator
60 /// @param i the ArrayIndex to compare to this ArrayIndex
61 /// @returns true if `i` and this ArrayIndex are equal
operator ==tint::transform::ZeroInitWorkgroupMemory::State::ArrayIndex62 bool operator==(const ArrayIndex& i) const {
63 return modulo == i.modulo && division == i.division;
64 }
65
66 /// Hash function for the ArrayIndex type
67 struct Hasher {
68 /// @param i the ArrayIndex to calculate a hash for
69 /// @returns the hash value for the ArrayIndex `i`
operator ()tint::transform::ZeroInitWorkgroupMemory::State::ArrayIndex::Hasher70 size_t operator()(const ArrayIndex& i) const {
71 return utils::Hash(i.modulo, i.division);
72 }
73 };
74 };
75
76 /// A list of unique ArrayIndex
77 using ArrayIndices = utils::UniqueVector<ArrayIndex, ArrayIndex::Hasher>;
78
79 /// Expression holds information about an expression that is being built for a
80 /// statement will zero workgroup values.
81 struct Expression {
82 /// The AST expression node
83 const ast::Expression* expr = nullptr;
84 /// The number of iterations required to zero the value
85 uint32_t num_iterations = 0;
86 /// All array indices used by this expression
87 ArrayIndices array_indices;
88 };
89
90 /// Statement holds information about a statement that will zero workgroup
91 /// values.
92 struct Statement {
93 /// The AST statement node
94 const ast::Statement* stmt;
95 /// The number of iterations required to zero the value
96 uint32_t num_iterations;
97 /// All array indices used by this statement
98 ArrayIndices array_indices;
99 };
100
101 /// All statements that zero workgroup memory
102 std::vector<Statement> statements;
103
104 /// A map of ArrayIndex to the name reserved for the `let` declaration of that
105 /// index.
106 std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names;
107
108 /// Constructor
109 /// @param c the CloneContext used for the transform
Statetint::transform::ZeroInitWorkgroupMemory::State110 explicit State(CloneContext& c) : ctx(c) {}
111
112 /// Run inserts the workgroup memory zero-initialization logic at the top of
113 /// the given function
114 /// @param fn a compute shader entry point function
Runtint::transform::ZeroInitWorkgroupMemory::State115 void Run(const ast::Function* fn) {
116 auto& sem = ctx.src->Sem();
117
118 CalculateWorkgroupSize(
119 ast::GetDecoration<ast::WorkgroupDecoration>(fn->decorations));
120
121 // Generate a list of statements to zero initialize each of the
122 // workgroup storage variables used by `fn`. This will populate #statements.
123 auto* func = sem.Get(fn);
124 for (auto* var : func->TransitivelyReferencedGlobals()) {
125 if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
126 BuildZeroingStatements(
127 var->Type()->UnwrapRef(), [&](uint32_t num_values) {
128 auto var_name = ctx.Clone(var->Declaration()->symbol);
129 return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
130 });
131 }
132 }
133
134 if (statements.empty()) {
135 return; // No workgroup variables to initialize.
136 }
137
138 // Scan the entry point for an existing local_invocation_index builtin
139 // parameter
140 std::function<const ast::Expression*()> local_index;
141 for (auto* param : fn->params) {
142 if (auto* builtin =
143 ast::GetDecoration<ast::BuiltinDecoration>(param->decorations)) {
144 if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
145 local_index = [=] { return b.Expr(ctx.Clone(param->symbol)); };
146 break;
147 }
148 }
149
150 if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) {
151 for (auto* member : str->Members()) {
152 if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
153 member->Declaration()->decorations)) {
154 if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
155 local_index = [=] {
156 auto* param_expr = b.Expr(ctx.Clone(param->symbol));
157 auto member_name = ctx.Clone(member->Declaration()->symbol);
158 return b.MemberAccessor(param_expr, member_name);
159 };
160 break;
161 }
162 }
163 }
164 }
165 }
166 if (!local_index) {
167 // No existing local index parameter. Append one to the entry point.
168 auto* param =
169 b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(),
170 {b.Builtin(ast::Builtin::kLocalInvocationIndex)});
171 ctx.InsertBack(fn->params, param);
172 local_index = [=] { return b.Expr(param->symbol); };
173 }
174
175 // Take the zeroing statements and bin them by the number of iterations
176 // required to zero the workgroup data. We then emit these in blocks,
177 // possibly wrapped in if-statements or for-loops.
178 std::unordered_map<uint32_t, std::vector<Statement>>
179 stmts_by_num_iterations;
180 std::vector<uint32_t> num_sorted_iterations;
181 for (auto& s : statements) {
182 auto& stmts = stmts_by_num_iterations[s.num_iterations];
183 if (stmts.empty()) {
184 num_sorted_iterations.emplace_back(s.num_iterations);
185 }
186 stmts.emplace_back(s);
187 }
188 std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end());
189
190 // Loop over the statements, grouped by num_iterations.
191 for (auto num_iterations : num_sorted_iterations) {
192 auto& stmts = stmts_by_num_iterations[num_iterations];
193
194 // Gather all the array indices used by all the statements in the block.
195 ArrayIndices array_indices;
196 for (auto& s : stmts) {
197 for (auto& idx : s.array_indices) {
198 array_indices.add(idx);
199 }
200 }
201
202 // Determine the block type used to emit these statements.
203
204 if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) {
205 // Either the workgroup size is dynamic, or smaller than num_iterations.
206 // In either case, we need to generate a for loop to ensure we
207 // initialize all the array elements.
208 //
209 // for (var idx : u32 = local_index;
210 // idx < num_iterations;
211 // idx += workgroup_size) {
212 // ...
213 // }
214 auto idx = b.Symbols().New("idx");
215 auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
216 auto* cond = b.create<ast::BinaryExpression>(
217 ast::BinaryOp::kLessThan, b.Expr(idx), b.Expr(num_iterations));
218 auto* cont = b.Assign(
219 idx, b.Add(idx, workgroup_size_const ? b.Expr(workgroup_size_const)
220 : workgroup_size_expr()));
221
222 auto block = DeclareArrayIndices(num_iterations, array_indices,
223 [&] { return b.Expr(idx); });
224 for (auto& s : stmts) {
225 block.emplace_back(s.stmt);
226 }
227 auto* for_loop = b.For(init, cond, cont, b.Block(block));
228 ctx.InsertFront(fn->body->statements, for_loop);
229 } else if (num_iterations < workgroup_size_const) {
230 // Workgroup size is a known constant, but is greater than
231 // num_iterations. Emit an if statement:
232 //
233 // if (local_index < num_iterations) {
234 // ...
235 // }
236 auto* cond = b.create<ast::BinaryExpression>(
237 ast::BinaryOp::kLessThan, local_index(), b.Expr(num_iterations));
238 auto block = DeclareArrayIndices(num_iterations, array_indices,
239 [&] { return b.Expr(local_index()); });
240 for (auto& s : stmts) {
241 block.emplace_back(s.stmt);
242 }
243 auto* if_stmt = b.If(cond, b.Block(block));
244 ctx.InsertFront(fn->body->statements, if_stmt);
245 } else {
246 // Workgroup size exactly equals num_iterations.
247 // No need for any conditionals. Just emit a basic block:
248 //
249 // {
250 // ...
251 // }
252 auto block = DeclareArrayIndices(num_iterations, array_indices,
253 [&] { return b.Expr(local_index()); });
254 for (auto& s : stmts) {
255 block.emplace_back(s.stmt);
256 }
257 ctx.InsertFront(fn->body->statements, b.Block(block));
258 }
259 }
260
261 // Append a single workgroup barrier after the zero initialization.
262 ctx.InsertFront(fn->body->statements,
263 b.CallStmt(b.Call("workgroupBarrier")));
264 }
265
266 /// BuildZeroingExpr is a function that builds a sub-expression used to zero
267 /// workgroup values. `num_values` is the number of elements that the
268 /// expression will be used to zero. Returns the expression.
269 using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>;
270
271 /// BuildZeroingStatements() generates the statements required to zero
272 /// initialize the workgroup storage expression of type `ty`.
273 /// @param ty the expression type
274 /// @param get_expr a function that builds the AST nodes for the expression.
BuildZeroingStatementstint::transform::ZeroInitWorkgroupMemory::State275 void BuildZeroingStatements(const sem::Type* ty,
276 const BuildZeroingExpr& get_expr) {
277 if (CanTriviallyZero(ty)) {
278 auto var = get_expr(1u);
279 auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
280 statements.emplace_back(Statement{b.Assign(var.expr, zero_init),
281 var.num_iterations, var.array_indices});
282 return;
283 }
284
285 if (auto* atomic = ty->As<sem::Atomic>()) {
286 auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
287 auto expr = get_expr(1u);
288 auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
289 statements.emplace_back(Statement{b.CallStmt(store), expr.num_iterations,
290 expr.array_indices});
291 return;
292 }
293
294 if (auto* str = ty->As<sem::Struct>()) {
295 for (auto* member : str->Members()) {
296 auto name = ctx.Clone(member->Declaration()->symbol);
297 BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
298 auto s = get_expr(num_values);
299 return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
300 s.array_indices};
301 });
302 }
303 return;
304 }
305
306 if (auto* arr = ty->As<sem::Array>()) {
307 BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
308 // num_values is the number of values to zero for the element type.
309 // The number of iterations required to zero the array and its elements
310 // is:
311 // `num_values * arr->Count()`
312 // The index for this array is:
313 // `(idx % modulo) / division`
314 auto modulo = num_values * arr->Count();
315 auto division = num_values;
316 auto a = get_expr(modulo);
317 auto array_indices = a.array_indices;
318 array_indices.add(ArrayIndex{modulo, division});
319 auto index =
320 utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
321 [&] { return b.Symbols().New("i"); });
322 return Expression{b.IndexAccessor(a.expr, index), a.num_iterations,
323 array_indices};
324 });
325 return;
326 }
327
328 TINT_UNREACHABLE(Transform, b.Diagnostics())
329 << "could not zero workgroup type: " << ty->type_name();
330 }
331
332 /// DeclareArrayIndices returns a list of statements that contain the `let`
333 /// declarations for all of the ArrayIndices.
334 /// @param num_iterations the number of iterations for the block
335 /// @param array_indices the list of array indices to generate `let`
336 /// declarations for
337 /// @param iteration a function that returns the index of the current
338 /// iteration.
339 /// @returns the list of `let` statements that declare the array indices
DeclareArrayIndicestint::transform::ZeroInitWorkgroupMemory::State340 ast::StatementList DeclareArrayIndices(
341 uint32_t num_iterations,
342 const ArrayIndices& array_indices,
343 const std::function<const ast::Expression*()>& iteration) {
344 ast::StatementList stmts;
345 std::map<Symbol, ArrayIndex> indices_by_name;
346 for (auto index : array_indices) {
347 auto name = array_index_names.at(index);
348 auto* mod =
349 (num_iterations > index.modulo)
350 ? b.create<ast::BinaryExpression>(
351 ast::BinaryOp::kModulo, iteration(), b.Expr(index.modulo))
352 : iteration();
353 auto* div = (index.division != 1u) ? b.Div(mod, index.division) : mod;
354 auto* decl = b.Decl(b.Const(name, b.ty.u32(), div));
355 stmts.emplace_back(decl);
356 }
357 return stmts;
358 }
359
360 /// CalculateWorkgroupSize initializes the members #workgroup_size_const and
361 /// #workgroup_size_expr with the linear workgroup size.
362 /// @param deco the workgroup decoration applied to the entry point function
CalculateWorkgroupSizetint::transform::ZeroInitWorkgroupMemory::State363 void CalculateWorkgroupSize(const ast::WorkgroupDecoration* deco) {
364 bool is_signed = false;
365 workgroup_size_const = 1u;
366 workgroup_size_expr = nullptr;
367 for (auto* expr : deco->Values()) {
368 if (!expr) {
369 continue;
370 }
371 auto* sem = ctx.src->Sem().Get(expr);
372 if (auto c = sem->ConstantValue()) {
373 if (c.ElementType()->Is<sem::I32>()) {
374 workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
375 continue;
376 } else if (c.ElementType()->Is<sem::U32>()) {
377 workgroup_size_const *= c.Elements()[0].u32;
378 continue;
379 }
380 }
381 // Constant value could not be found. Build expression instead.
382 workgroup_size_expr = [this, expr, size = workgroup_size_expr] {
383 auto* e = ctx.Clone(expr);
384 if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<sem::I32>()) {
385 e = b.Construct<ProgramBuilder::u32>(e);
386 }
387 return size ? b.Mul(size(), e) : e;
388 };
389 }
390 if (workgroup_size_expr) {
391 if (workgroup_size_const != 1) {
392 // Fold workgroup_size_const in to workgroup_size_expr
393 workgroup_size_expr = [this, is_signed,
394 const_size = workgroup_size_const,
395 expr_size = workgroup_size_expr] {
396 return is_signed
397 ? b.Mul(expr_size(), static_cast<int32_t>(const_size))
398 : b.Mul(expr_size(), const_size);
399 };
400 }
401 // Indicate that workgroup_size_expr should be used instead of the
402 // constant.
403 workgroup_size_const = 0;
404 }
405 }
406
407 /// @returns true if a variable with store type `ty` can be efficiently zeroed
408 /// by assignment of a type constructor without operands. If
409 /// CanTriviallyZero() returns false, then the type needs to be
410 /// initialized by decomposing the initialization into multiple
411 /// sub-initializations.
412 /// @param ty the type to inspect
CanTriviallyZerotint::transform::ZeroInitWorkgroupMemory::State413 bool CanTriviallyZero(const sem::Type* ty) {
414 if (ty->Is<sem::Atomic>()) {
415 return false;
416 }
417 if (auto* str = ty->As<sem::Struct>()) {
418 for (auto* member : str->Members()) {
419 if (!CanTriviallyZero(member->Type())) {
420 return false;
421 }
422 }
423 }
424 if (ty->Is<sem::Array>()) {
425 return false;
426 }
427 // True for all other storable types
428 return true;
429 }
430 };
431
432 ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
433
434 ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
435
Run(CloneContext & ctx,const DataMap &,DataMap &)436 void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
437 for (auto* fn : ctx.src->AST().Functions()) {
438 if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
439 State{ctx}.Run(fn);
440 }
441 }
442 ctx.Clone();
443 }
444
445 } // namespace transform
446 } // namespace tint
447