• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/module_scope_var_to_entry_point_param.h"
16 
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <utility>
20 #include <vector>
21 
22 #include "src/ast/disable_validation_decoration.h"
23 #include "src/program_builder.h"
24 #include "src/sem/call.h"
25 #include "src/sem/function.h"
26 #include "src/sem/statement.h"
27 #include "src/sem/variable.h"
28 
29 TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
30 
31 namespace tint {
32 namespace transform {
33 namespace {
34 // Returns `true` if `type` is or contains a matrix type.
ContainsMatrix(const sem::Type * type)35 bool ContainsMatrix(const sem::Type* type) {
36   type = type->UnwrapRef();
37   if (type->Is<sem::Matrix>()) {
38     return true;
39   } else if (auto* ary = type->As<sem::Array>()) {
40     return ContainsMatrix(ary->ElemType());
41   } else if (auto* str = type->As<sem::Struct>()) {
42     for (auto* member : str->Members()) {
43       if (ContainsMatrix(member->Type())) {
44         return true;
45       }
46     }
47   }
48   return false;
49 }
50 }  // namespace
51 
52 /// State holds the current transform state.
53 struct ModuleScopeVarToEntryPointParam::State {
54   /// The clone context.
55   CloneContext& ctx;
56 
57   /// Constructor
58   /// @param context the clone context
Statetint::transform::ModuleScopeVarToEntryPointParam::State59   explicit State(CloneContext& context) : ctx(context) {}
60 
61   /// Clone any struct types that are contained in `ty` (including `ty` itself),
62   /// and add it to the global declarations now, so that they precede new global
63   /// declarations that need to reference them.
64   /// @param ty the type to clone
CloneStructTypestint::transform::ModuleScopeVarToEntryPointParam::State65   void CloneStructTypes(const sem::Type* ty) {
66     if (auto* str = ty->As<sem::Struct>()) {
67       if (!cloned_structs_.emplace(str).second) {
68         // The struct has already been cloned.
69         return;
70       }
71 
72       // Recurse into members.
73       for (auto* member : str->Members()) {
74         CloneStructTypes(member->Type());
75       }
76 
77       // Clone the struct and add it to the global declaration list.
78       // Remove the old declaration.
79       auto* ast_str = str->Declaration();
80       ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
81       ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
82     } else if (auto* arr = ty->As<sem::Array>()) {
83       CloneStructTypes(arr->ElemType());
84     }
85   }
86 
87   /// Process the module.
Processtint::transform::ModuleScopeVarToEntryPointParam::State88   void Process() {
89     // Predetermine the list of function calls that need to be replaced.
90     using CallList = std::vector<const ast::CallExpression*>;
91     std::unordered_map<const ast::Function*, CallList> calls_to_replace;
92 
93     std::vector<const ast::Function*> functions_to_process;
94 
95     // Build a list of functions that transitively reference any module-scope
96     // variables.
97     for (auto* func_ast : ctx.src->AST().Functions()) {
98       auto* func_sem = ctx.src->Sem().Get(func_ast);
99 
100       bool needs_processing = false;
101       for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
102         if (var->StorageClass() != ast::StorageClass::kNone) {
103           needs_processing = true;
104           break;
105         }
106       }
107       if (needs_processing) {
108         functions_to_process.push_back(func_ast);
109 
110         // Find all of the calls to this function that will need to be replaced.
111         for (auto* call : func_sem->CallSites()) {
112           calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
113               call->Declaration());
114         }
115       }
116     }
117 
118     // Build a list of `&ident` expressions. We'll use this later to avoid
119     // generating expressions of the form `&*ident`, which break WGSL validation
120     // rules when this expression is passed to a function.
121     // TODO(jrprice): We should add support for bidirectional SEM tree traversal
122     // so that we can do this on the fly instead.
123     std::unordered_map<const ast::IdentifierExpression*,
124                        const ast::UnaryOpExpression*>
125         ident_to_address_of;
126     for (auto* node : ctx.src->ASTNodes().Objects()) {
127       auto* address_of = node->As<ast::UnaryOpExpression>();
128       if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
129         continue;
130       }
131       if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
132         ident_to_address_of[ident] = address_of;
133       }
134     }
135 
136     for (auto* func_ast : functions_to_process) {
137       auto* func_sem = ctx.src->Sem().Get(func_ast);
138       bool is_entry_point = func_ast->IsEntryPoint();
139 
140       // Map module-scope variables onto their replacement.
141       struct NewVar {
142         Symbol symbol;
143         bool is_pointer;
144       };
145       std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
146 
147       // We aggregate all workgroup variables into a struct to avoid hitting
148       // MSL's limit for threadgroup memory arguments.
149       Symbol workgroup_parameter_symbol;
150       ast::StructMemberList workgroup_parameter_members;
151       auto workgroup_param = [&]() {
152         if (!workgroup_parameter_symbol.IsValid()) {
153           workgroup_parameter_symbol = ctx.dst->Sym();
154         }
155         return workgroup_parameter_symbol;
156       };
157 
158       for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
159         auto sc = var->StorageClass();
160         if (sc == ast::StorageClass::kNone) {
161           continue;
162         }
163         if (sc != ast::StorageClass::kPrivate &&
164             sc != ast::StorageClass::kStorage &&
165             sc != ast::StorageClass::kUniform &&
166             sc != ast::StorageClass::kUniformConstant &&
167             sc != ast::StorageClass::kWorkgroup) {
168           TINT_ICE(Transform, ctx.dst->Diagnostics())
169               << "unhandled module-scope storage class (" << sc << ")";
170         }
171 
172         // This is the symbol for the variable that replaces the module-scope
173         // var.
174         auto new_var_symbol = ctx.dst->Sym();
175 
176         // Helper to create an AST node for the store type of the variable.
177         auto store_type = [&]() {
178           return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
179         };
180 
181         // Track whether the new variable is a pointer or not.
182         bool is_pointer = false;
183 
184         if (is_entry_point) {
185           if (var->Type()->UnwrapRef()->is_handle()) {
186             // For a texture or sampler variable, redeclare it as an entry point
187             // parameter. Disable entry point parameter validation.
188             auto* disable_validation =
189                 ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
190             auto decos = ctx.Clone(var->Declaration()->decorations);
191             decos.push_back(disable_validation);
192             auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
193             ctx.InsertFront(func_ast->params, param);
194           } else if (sc == ast::StorageClass::kStorage ||
195                      sc == ast::StorageClass::kUniform) {
196             // Variables into the Storage and Uniform storage classes are
197             // redeclared as entry point parameters with a pointer type.
198             auto attributes = ctx.Clone(var->Declaration()->decorations);
199             attributes.push_back(ctx.dst->Disable(
200                 ast::DisabledValidation::kEntryPointParameter));
201             attributes.push_back(
202                 ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
203             auto* param_type = ctx.dst->ty.pointer(
204                 store_type(), sc, var->Declaration()->declared_access);
205             auto* param =
206                 ctx.dst->Param(new_var_symbol, param_type, attributes);
207             ctx.InsertFront(func_ast->params, param);
208             is_pointer = true;
209           } else if (sc == ast::StorageClass::kWorkgroup &&
210                      ContainsMatrix(var->Type())) {
211             // Due to a bug in the MSL compiler, we use a threadgroup memory
212             // argument for any workgroup allocation that contains a matrix.
213             // See crbug.com/tint/938.
214             // TODO(jrprice): Do this for all other workgroup variables too.
215 
216             // Create a member in the workgroup parameter struct.
217             auto member = ctx.Clone(var->Declaration()->symbol);
218             workgroup_parameter_members.push_back(
219                 ctx.dst->Member(member, store_type()));
220             CloneStructTypes(var->Type()->UnwrapRef());
221 
222             // Create a function-scope variable that is a pointer to the member.
223             auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
224                 ctx.dst->Deref(workgroup_param()), member));
225             auto* local_var =
226                 ctx.dst->Const(new_var_symbol,
227                                ctx.dst->ty.pointer(
228                                    store_type(), ast::StorageClass::kWorkgroup),
229                                member_ptr);
230             ctx.InsertFront(func_ast->body->statements,
231                             ctx.dst->Decl(local_var));
232             is_pointer = true;
233           } else {
234             // Variables in the Private and Workgroup storage classes are
235             // redeclared at function scope. Disable storage class validation on
236             // this variable.
237             auto* disable_validation =
238                 ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
239             auto* constructor = ctx.Clone(var->Declaration()->constructor);
240             auto* local_var =
241                 ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
242                              ast::DecorationList{disable_validation});
243             ctx.InsertFront(func_ast->body->statements,
244                             ctx.dst->Decl(local_var));
245           }
246         } else {
247           // For a regular function, redeclare the variable as a parameter.
248           // Use a pointer for non-handle types.
249           auto* param_type = store_type();
250           ast::DecorationList attributes;
251           if (!var->Type()->UnwrapRef()->is_handle()) {
252             param_type = ctx.dst->ty.pointer(
253                 param_type, sc, var->Declaration()->declared_access);
254             is_pointer = true;
255 
256             // Disable validation of the parameter's storage class and of
257             // arguments passed it.
258             attributes.push_back(
259                 ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
260             attributes.push_back(ctx.dst->Disable(
261                 ast::DisabledValidation::kIgnoreInvalidPointerArgument));
262           }
263           ctx.InsertBack(
264               func_ast->params,
265               ctx.dst->Param(new_var_symbol, param_type, attributes));
266         }
267 
268         // Replace all uses of the module-scope variable.
269         // For non-entry points, dereference non-handle pointer parameters.
270         for (auto* user : var->Users()) {
271           if (user->Stmt()->Function()->Declaration() == func_ast) {
272             const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
273             if (is_pointer) {
274               // If this identifier is used by an address-of operator, just
275               // remove the address-of instead of adding a deref, since we
276               // already have a pointer.
277               auto* ident =
278                   user->Declaration()->As<ast::IdentifierExpression>();
279               if (ident_to_address_of.count(ident)) {
280                 ctx.Replace(ident_to_address_of[ident], expr);
281                 continue;
282               }
283 
284               expr = ctx.dst->Deref(expr);
285             }
286             ctx.Replace(user->Declaration(), expr);
287           }
288         }
289 
290         var_to_newvar[var] = {new_var_symbol, is_pointer};
291       }
292 
293       if (!workgroup_parameter_members.empty()) {
294         // Create the workgroup memory parameter.
295         // The parameter is a struct that contains members for each workgroup
296         // variable.
297         auto* str = ctx.dst->Structure(ctx.dst->Sym(),
298                                        std::move(workgroup_parameter_members));
299         auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
300                                                ast::StorageClass::kWorkgroup);
301         auto* disable_validation =
302             ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
303         auto* param =
304             ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
305         ctx.InsertFront(func_ast->params, param);
306       }
307 
308       // Pass the variables as pointers to any functions that need them.
309       for (auto* call : calls_to_replace[func_ast]) {
310         auto* target =
311             ctx.src->AST().Functions().Find(call->target.name->symbol);
312         auto* target_sem = ctx.src->Sem().Get(target);
313 
314         // Add new arguments for any variables that are needed by the callee.
315         // For entry points, pass non-handle types as pointers.
316         for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
317           auto sc = target_var->StorageClass();
318           if (sc == ast::StorageClass::kNone) {
319             continue;
320           }
321 
322           auto new_var = var_to_newvar[target_var];
323           bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
324           const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
325           if (is_entry_point && !is_handle && !new_var.is_pointer) {
326             // We need to pass a pointer and we don't already have one, so take
327             // the address of the new variable.
328             arg = ctx.dst->AddressOf(arg);
329           }
330           ctx.InsertBack(call->args, arg);
331         }
332       }
333     }
334 
335     // Now remove all module-scope variables with these storage classes.
336     for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
337       auto* var_sem = ctx.src->Sem().Get(var_ast);
338       if (var_sem->StorageClass() != ast::StorageClass::kNone) {
339         ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
340       }
341     }
342   }
343 
344  private:
345   std::unordered_set<const sem::Struct*> cloned_structs_;
346 };
347 
348 ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
349 
350 ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
351 
Run(CloneContext & ctx,const DataMap &,DataMap &)352 void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
353                                           const DataMap&,
354                                           DataMap&) {
355   State state{ctx};
356   state.Process();
357   ctx.Clone();
358 }
359 
360 }  // namespace transform
361 }  // namespace tint
362