• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/zero_init_workgroup_memory.h"
16 
17 #include <algorithm>
18 #include <map>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22 
23 #include "src/ast/workgroup_decoration.h"
24 #include "src/program_builder.h"
25 #include "src/sem/atomic_type.h"
26 #include "src/sem/function.h"
27 #include "src/sem/variable.h"
28 #include "src/utils/map.h"
29 #include "src/utils/unique_vector.h"
30 
// Instantiates the TypeInfo for the ZeroInitWorkgroupMemory transform class.
// NOTE(review): presumably this backs Tint's Castable RTTI (Is<>/As<>) for this
// class — confirm against the TINT_INSTANTIATE_TYPEINFO macro definition.
TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory);
32 
33 namespace tint {
34 namespace transform {
35 
36 /// PIMPL state for the ZeroInitWorkgroupMemory transform
37 struct ZeroInitWorkgroupMemory::State {
38   /// The clone context
39   CloneContext& ctx;
40 
41   /// An alias to *ctx.dst
42   ProgramBuilder& b = *ctx.dst;
43 
44   /// The constant size of the workgroup. If 0, then #workgroup_size_expr should
45   /// be used instead.
46   uint32_t workgroup_size_const = 0;
47   /// The size of the workgroup as an expression generator. Use if
48   /// #workgroup_size_const is 0.
49   std::function<const ast::Expression*()> workgroup_size_expr;
50 
51   /// ArrayIndex represents a function on the local invocation index, of
52   /// the form: `array_index = (local_invocation_index % modulo) / division`
53   struct ArrayIndex {
54     /// The RHS of the modulus part of the expression
55     uint32_t modulo = 1;
56     /// The RHS of the division part of the expression
57     uint32_t division = 1;
58 
59     /// Equality operator
60     /// @param i the ArrayIndex to compare to this ArrayIndex
61     /// @returns true if `i` and this ArrayIndex are equal
operator ==tint::transform::ZeroInitWorkgroupMemory::State::ArrayIndex62     bool operator==(const ArrayIndex& i) const {
63       return modulo == i.modulo && division == i.division;
64     }
65 
66     /// Hash function for the ArrayIndex type
67     struct Hasher {
68       /// @param i the ArrayIndex to calculate a hash for
69       /// @returns the hash value for the ArrayIndex `i`
operator ()tint::transform::ZeroInitWorkgroupMemory::State::ArrayIndex::Hasher70       size_t operator()(const ArrayIndex& i) const {
71         return utils::Hash(i.modulo, i.division);
72       }
73     };
74   };
75 
76   /// A list of unique ArrayIndex
77   using ArrayIndices = utils::UniqueVector<ArrayIndex, ArrayIndex::Hasher>;
78 
79   /// Expression holds information about an expression that is being built for a
80   /// statement will zero workgroup values.
81   struct Expression {
82     /// The AST expression node
83     const ast::Expression* expr = nullptr;
84     /// The number of iterations required to zero the value
85     uint32_t num_iterations = 0;
86     /// All array indices used by this expression
87     ArrayIndices array_indices;
88   };
89 
90   /// Statement holds information about a statement that will zero workgroup
91   /// values.
92   struct Statement {
93     /// The AST statement node
94     const ast::Statement* stmt;
95     /// The number of iterations required to zero the value
96     uint32_t num_iterations;
97     /// All array indices used by this statement
98     ArrayIndices array_indices;
99   };
100 
101   /// All statements that zero workgroup memory
102   std::vector<Statement> statements;
103 
104   /// A map of ArrayIndex to the name reserved for the `let` declaration of that
105   /// index.
106   std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names;
107 
108   /// Constructor
109   /// @param c the CloneContext used for the transform
Statetint::transform::ZeroInitWorkgroupMemory::State110   explicit State(CloneContext& c) : ctx(c) {}
111 
112   /// Run inserts the workgroup memory zero-initialization logic at the top of
113   /// the given function
114   /// @param fn a compute shader entry point function
Runtint::transform::ZeroInitWorkgroupMemory::State115   void Run(const ast::Function* fn) {
116     auto& sem = ctx.src->Sem();
117 
118     CalculateWorkgroupSize(
119         ast::GetDecoration<ast::WorkgroupDecoration>(fn->decorations));
120 
121     // Generate a list of statements to zero initialize each of the
122     // workgroup storage variables used by `fn`. This will populate #statements.
123     auto* func = sem.Get(fn);
124     for (auto* var : func->TransitivelyReferencedGlobals()) {
125       if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
126         BuildZeroingStatements(
127             var->Type()->UnwrapRef(), [&](uint32_t num_values) {
128               auto var_name = ctx.Clone(var->Declaration()->symbol);
129               return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
130             });
131       }
132     }
133 
134     if (statements.empty()) {
135       return;  // No workgroup variables to initialize.
136     }
137 
138     // Scan the entry point for an existing local_invocation_index builtin
139     // parameter
140     std::function<const ast::Expression*()> local_index;
141     for (auto* param : fn->params) {
142       if (auto* builtin =
143               ast::GetDecoration<ast::BuiltinDecoration>(param->decorations)) {
144         if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
145           local_index = [=] { return b.Expr(ctx.Clone(param->symbol)); };
146           break;
147         }
148       }
149 
150       if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) {
151         for (auto* member : str->Members()) {
152           if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
153                   member->Declaration()->decorations)) {
154             if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
155               local_index = [=] {
156                 auto* param_expr = b.Expr(ctx.Clone(param->symbol));
157                 auto member_name = ctx.Clone(member->Declaration()->symbol);
158                 return b.MemberAccessor(param_expr, member_name);
159               };
160               break;
161             }
162           }
163         }
164       }
165     }
166     if (!local_index) {
167       // No existing local index parameter. Append one to the entry point.
168       auto* param =
169           b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(),
170                   {b.Builtin(ast::Builtin::kLocalInvocationIndex)});
171       ctx.InsertBack(fn->params, param);
172       local_index = [=] { return b.Expr(param->symbol); };
173     }
174 
175     // Take the zeroing statements and bin them by the number of iterations
176     // required to zero the workgroup data. We then emit these in blocks,
177     // possibly wrapped in if-statements or for-loops.
178     std::unordered_map<uint32_t, std::vector<Statement>>
179         stmts_by_num_iterations;
180     std::vector<uint32_t> num_sorted_iterations;
181     for (auto& s : statements) {
182       auto& stmts = stmts_by_num_iterations[s.num_iterations];
183       if (stmts.empty()) {
184         num_sorted_iterations.emplace_back(s.num_iterations);
185       }
186       stmts.emplace_back(s);
187     }
188     std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end());
189 
190     // Loop over the statements, grouped by num_iterations.
191     for (auto num_iterations : num_sorted_iterations) {
192       auto& stmts = stmts_by_num_iterations[num_iterations];
193 
194       // Gather all the array indices used by all the statements in the block.
195       ArrayIndices array_indices;
196       for (auto& s : stmts) {
197         for (auto& idx : s.array_indices) {
198           array_indices.add(idx);
199         }
200       }
201 
202       // Determine the block type used to emit these statements.
203 
204       if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) {
205         // Either the workgroup size is dynamic, or smaller than num_iterations.
206         // In either case, we need to generate a for loop to ensure we
207         // initialize all the array elements.
208         //
209         //  for (var idx : u32 = local_index;
210         //           idx < num_iterations;
211         //           idx += workgroup_size) {
212         //    ...
213         //  }
214         auto idx = b.Symbols().New("idx");
215         auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
216         auto* cond = b.create<ast::BinaryExpression>(
217             ast::BinaryOp::kLessThan, b.Expr(idx), b.Expr(num_iterations));
218         auto* cont = b.Assign(
219             idx, b.Add(idx, workgroup_size_const ? b.Expr(workgroup_size_const)
220                                                  : workgroup_size_expr()));
221 
222         auto block = DeclareArrayIndices(num_iterations, array_indices,
223                                          [&] { return b.Expr(idx); });
224         for (auto& s : stmts) {
225           block.emplace_back(s.stmt);
226         }
227         auto* for_loop = b.For(init, cond, cont, b.Block(block));
228         ctx.InsertFront(fn->body->statements, for_loop);
229       } else if (num_iterations < workgroup_size_const) {
230         // Workgroup size is a known constant, but is greater than
231         // num_iterations. Emit an if statement:
232         //
233         //  if (local_index < num_iterations) {
234         //    ...
235         //  }
236         auto* cond = b.create<ast::BinaryExpression>(
237             ast::BinaryOp::kLessThan, local_index(), b.Expr(num_iterations));
238         auto block = DeclareArrayIndices(num_iterations, array_indices,
239                                          [&] { return b.Expr(local_index()); });
240         for (auto& s : stmts) {
241           block.emplace_back(s.stmt);
242         }
243         auto* if_stmt = b.If(cond, b.Block(block));
244         ctx.InsertFront(fn->body->statements, if_stmt);
245       } else {
246         // Workgroup size exactly equals num_iterations.
247         // No need for any conditionals. Just emit a basic block:
248         //
249         // {
250         //    ...
251         // }
252         auto block = DeclareArrayIndices(num_iterations, array_indices,
253                                          [&] { return b.Expr(local_index()); });
254         for (auto& s : stmts) {
255           block.emplace_back(s.stmt);
256         }
257         ctx.InsertFront(fn->body->statements, b.Block(block));
258       }
259     }
260 
261     // Append a single workgroup barrier after the zero initialization.
262     ctx.InsertFront(fn->body->statements,
263                     b.CallStmt(b.Call("workgroupBarrier")));
264   }
265 
266   /// BuildZeroingExpr is a function that builds a sub-expression used to zero
267   /// workgroup values. `num_values` is the number of elements that the
268   /// expression will be used to zero. Returns the expression.
269   using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>;
270 
271   /// BuildZeroingStatements() generates the statements required to zero
272   /// initialize the workgroup storage expression of type `ty`.
273   /// @param ty the expression type
274   /// @param get_expr a function that builds the AST nodes for the expression.
BuildZeroingStatementstint::transform::ZeroInitWorkgroupMemory::State275   void BuildZeroingStatements(const sem::Type* ty,
276                               const BuildZeroingExpr& get_expr) {
277     if (CanTriviallyZero(ty)) {
278       auto var = get_expr(1u);
279       auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
280       statements.emplace_back(Statement{b.Assign(var.expr, zero_init),
281                                         var.num_iterations, var.array_indices});
282       return;
283     }
284 
285     if (auto* atomic = ty->As<sem::Atomic>()) {
286       auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
287       auto expr = get_expr(1u);
288       auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
289       statements.emplace_back(Statement{b.CallStmt(store), expr.num_iterations,
290                                         expr.array_indices});
291       return;
292     }
293 
294     if (auto* str = ty->As<sem::Struct>()) {
295       for (auto* member : str->Members()) {
296         auto name = ctx.Clone(member->Declaration()->symbol);
297         BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
298           auto s = get_expr(num_values);
299           return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
300                             s.array_indices};
301         });
302       }
303       return;
304     }
305 
306     if (auto* arr = ty->As<sem::Array>()) {
307       BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
308         // num_values is the number of values to zero for the element type.
309         // The number of iterations required to zero the array and its elements
310         // is:
311         //      `num_values * arr->Count()`
312         // The index for this array is:
313         //      `(idx % modulo) / division`
314         auto modulo = num_values * arr->Count();
315         auto division = num_values;
316         auto a = get_expr(modulo);
317         auto array_indices = a.array_indices;
318         array_indices.add(ArrayIndex{modulo, division});
319         auto index =
320             utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
321                                [&] { return b.Symbols().New("i"); });
322         return Expression{b.IndexAccessor(a.expr, index), a.num_iterations,
323                           array_indices};
324       });
325       return;
326     }
327 
328     TINT_UNREACHABLE(Transform, b.Diagnostics())
329         << "could not zero workgroup type: " << ty->type_name();
330   }
331 
332   /// DeclareArrayIndices returns a list of statements that contain the `let`
333   /// declarations for all of the ArrayIndices.
334   /// @param num_iterations the number of iterations for the block
335   /// @param array_indices the list of array indices to generate `let`
336   ///        declarations for
337   /// @param iteration a function that returns the index of the current
338   ///         iteration.
339   /// @returns the list of `let` statements that declare the array indices
DeclareArrayIndicestint::transform::ZeroInitWorkgroupMemory::State340   ast::StatementList DeclareArrayIndices(
341       uint32_t num_iterations,
342       const ArrayIndices& array_indices,
343       const std::function<const ast::Expression*()>& iteration) {
344     ast::StatementList stmts;
345     std::map<Symbol, ArrayIndex> indices_by_name;
346     for (auto index : array_indices) {
347       auto name = array_index_names.at(index);
348       auto* mod =
349           (num_iterations > index.modulo)
350               ? b.create<ast::BinaryExpression>(
351                     ast::BinaryOp::kModulo, iteration(), b.Expr(index.modulo))
352               : iteration();
353       auto* div = (index.division != 1u) ? b.Div(mod, index.division) : mod;
354       auto* decl = b.Decl(b.Const(name, b.ty.u32(), div));
355       stmts.emplace_back(decl);
356     }
357     return stmts;
358   }
359 
360   /// CalculateWorkgroupSize initializes the members #workgroup_size_const and
361   /// #workgroup_size_expr with the linear workgroup size.
362   /// @param deco the workgroup decoration applied to the entry point function
CalculateWorkgroupSizetint::transform::ZeroInitWorkgroupMemory::State363   void CalculateWorkgroupSize(const ast::WorkgroupDecoration* deco) {
364     bool is_signed = false;
365     workgroup_size_const = 1u;
366     workgroup_size_expr = nullptr;
367     for (auto* expr : deco->Values()) {
368       if (!expr) {
369         continue;
370       }
371       auto* sem = ctx.src->Sem().Get(expr);
372       if (auto c = sem->ConstantValue()) {
373         if (c.ElementType()->Is<sem::I32>()) {
374           workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
375           continue;
376         } else if (c.ElementType()->Is<sem::U32>()) {
377           workgroup_size_const *= c.Elements()[0].u32;
378           continue;
379         }
380       }
381       // Constant value could not be found. Build expression instead.
382       workgroup_size_expr = [this, expr, size = workgroup_size_expr] {
383         auto* e = ctx.Clone(expr);
384         if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<sem::I32>()) {
385           e = b.Construct<ProgramBuilder::u32>(e);
386         }
387         return size ? b.Mul(size(), e) : e;
388       };
389     }
390     if (workgroup_size_expr) {
391       if (workgroup_size_const != 1) {
392         // Fold workgroup_size_const in to workgroup_size_expr
393         workgroup_size_expr = [this, is_signed,
394                                const_size = workgroup_size_const,
395                                expr_size = workgroup_size_expr] {
396           return is_signed
397                      ? b.Mul(expr_size(), static_cast<int32_t>(const_size))
398                      : b.Mul(expr_size(), const_size);
399         };
400       }
401       // Indicate that workgroup_size_expr should be used instead of the
402       // constant.
403       workgroup_size_const = 0;
404     }
405   }
406 
407   /// @returns true if a variable with store type `ty` can be efficiently zeroed
408   /// by assignment of a type constructor without operands. If
409   /// CanTriviallyZero() returns false, then the type needs to be
410   /// initialized by decomposing the initialization into multiple
411   /// sub-initializations.
412   /// @param ty the type to inspect
CanTriviallyZerotint::transform::ZeroInitWorkgroupMemory::State413   bool CanTriviallyZero(const sem::Type* ty) {
414     if (ty->Is<sem::Atomic>()) {
415       return false;
416     }
417     if (auto* str = ty->As<sem::Struct>()) {
418       for (auto* member : str->Members()) {
419         if (!CanTriviallyZero(member->Type())) {
420           return false;
421         }
422       }
423     }
424     if (ty->Is<sem::Array>()) {
425       return false;
426     }
427     // True for all other storable types
428     return true;
429   }
430 };
431 
/// Constructor. Defaulted: the transform keeps no state of its own, and all
/// per-entry-point data lives in a State object built inside Run().
ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;

/// Destructor
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
435 
Run(CloneContext & ctx,const DataMap &,DataMap &)436 void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
437   for (auto* fn : ctx.src->AST().Functions()) {
438     if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
439       State{ctx}.Run(fn);
440     }
441   }
442   ctx.Clone();
443 }
444 
445 }  // namespace transform
446 }  // namespace tint
447