1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/module_scope_var_to_entry_point_param.h"
16
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <utility>
20 #include <vector>
21
22 #include "src/ast/disable_validation_decoration.h"
23 #include "src/program_builder.h"
24 #include "src/sem/call.h"
25 #include "src/sem/function.h"
26 #include "src/sem/statement.h"
27 #include "src/sem/variable.h"
28
29 TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
30
31 namespace tint {
32 namespace transform {
33 namespace {
34 // Returns `true` if `type` is or contains a matrix type.
ContainsMatrix(const sem::Type * type)35 bool ContainsMatrix(const sem::Type* type) {
36 type = type->UnwrapRef();
37 if (type->Is<sem::Matrix>()) {
38 return true;
39 } else if (auto* ary = type->As<sem::Array>()) {
40 return ContainsMatrix(ary->ElemType());
41 } else if (auto* str = type->As<sem::Struct>()) {
42 for (auto* member : str->Members()) {
43 if (ContainsMatrix(member->Type())) {
44 return true;
45 }
46 }
47 }
48 return false;
49 }
50 } // namespace
51
/// State holds the current transform state.
/// It wraps the clone context for one run of the transform and remembers which
/// struct types have already been re-emitted by CloneStructTypes().
struct ModuleScopeVarToEntryPointParam::State {
  /// The clone context.
  CloneContext& ctx;

  /// Constructor
  /// @param context the clone context
  explicit State(CloneContext& context) : ctx(context) {}

  /// Clone any struct types that are contained in `ty` (including `ty` itself),
  /// and add them to the global declarations now, so that they precede new
  /// global declarations that need to reference them.
  /// Member types and array element types are visited before the enclosing
  /// struct is emitted, so a struct's nested struct dependencies are declared
  /// ahead of it. Each struct is emitted at most once (tracked in
  /// `cloned_structs_`), and its original declaration is removed from the
  /// source module's global declaration list so it is not cloned a second time.
  /// @param ty the type to clone
  void CloneStructTypes(const sem::Type* ty) {
    if (auto* str = ty->As<sem::Struct>()) {
      if (!cloned_structs_.emplace(str).second) {
        // The struct has already been cloned.
        return;
      }

      // Recurse into members.
      for (auto* member : str->Members()) {
        CloneStructTypes(member->Type());
      }

      // Clone the struct and add it to the global declaration list.
      // Remove the old declaration.
      auto* ast_str = str->Declaration();
      ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
      ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
    } else if (auto* arr = ty->As<sem::Array>()) {
      // Arrays are not declarations themselves; only their element type may
      // need cloning.
      CloneStructTypes(arr->ElemType());
    }
  }

  /// Process the module.
  /// Every function that transitively references a module-scope variable whose
  /// storage class is not `kNone` is rewritten so that the variable is reached
  /// through an entry point parameter, a function-scope declaration, or an
  /// extra function parameter. Call sites of such functions are given the
  /// matching extra arguments. Finally, those module-scope variable
  /// declarations are removed.
  void Process() {
    // Predetermine the list of function calls that need to be replaced.
    // The map is keyed by the *calling* function; each entry lists the call
    // expressions in that function whose callee needs extra arguments.
    using CallList = std::vector<const ast::CallExpression*>;
    std::unordered_map<const ast::Function*, CallList> calls_to_replace;

    std::vector<const ast::Function*> functions_to_process;

    // Build a list of functions that transitively reference any module-scope
    // variables. Globals with storage class `kNone` are ignored here and
    // throughout this transform.
    for (auto* func_ast : ctx.src->AST().Functions()) {
      auto* func_sem = ctx.src->Sem().Get(func_ast);

      bool needs_processing = false;
      for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
        if (var->StorageClass() != ast::StorageClass::kNone) {
          needs_processing = true;
          break;
        }
      }
      if (needs_processing) {
        functions_to_process.push_back(func_ast);

        // Find all of the calls to this function that will need to be replaced.
        for (auto* call : func_sem->CallSites()) {
          calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
              call->Declaration());
        }
      }
    }

    // Build a list of `&ident` expressions. We'll use this later to avoid
    // generating expressions of the form `&*ident`, which break WGSL validation
    // rules when this expression is passed to a function.
    // TODO(jrprice): We should add support for bidirectional SEM tree traversal
    // so that we can do this on the fly instead.
    std::unordered_map<const ast::IdentifierExpression*,
                       const ast::UnaryOpExpression*>
        ident_to_address_of;
    for (auto* node : ctx.src->ASTNodes().Objects()) {
      auto* address_of = node->As<ast::UnaryOpExpression>();
      if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
        continue;
      }
      if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
        ident_to_address_of[ident] = address_of;
      }
    }

    for (auto* func_ast : functions_to_process) {
      auto* func_sem = ctx.src->Sem().Get(func_ast);
      bool is_entry_point = func_ast->IsEntryPoint();

      // Map module-scope variables onto their replacement.
      // `is_pointer` records whether `symbol` names a pointer to the storage
      // (so uses must be dereferenced) or the storage itself.
      struct NewVar {
        Symbol symbol;
        bool is_pointer;
      };
      std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;

      // We aggregate all workgroup variables into a struct to avoid hitting
      // MSL's limit for threadgroup memory arguments.
      Symbol workgroup_parameter_symbol;
      ast::StructMemberList workgroup_parameter_members;
      // Lazily creates the symbol for the aggregated workgroup parameter on
      // first use, so every member pointer and the parameter declared after
      // this loop share the same symbol.
      auto workgroup_param = [&]() {
        if (!workgroup_parameter_symbol.IsValid()) {
          workgroup_parameter_symbol = ctx.dst->Sym();
        }
        return workgroup_parameter_symbol;
      };

      for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
        auto sc = var->StorageClass();
        if (sc == ast::StorageClass::kNone) {
          continue;
        }
        if (sc != ast::StorageClass::kPrivate &&
            sc != ast::StorageClass::kStorage &&
            sc != ast::StorageClass::kUniform &&
            sc != ast::StorageClass::kUniformConstant &&
            sc != ast::StorageClass::kWorkgroup) {
          // NOTE(review): no early exit follows this report within this
          // block; whether TINT_ICE itself aborts is defined elsewhere.
          TINT_ICE(Transform, ctx.dst->Diagnostics())
              << "unhandled module-scope storage class (" << sc << ")";
        }

        // This is the symbol for the variable that replaces the module-scope
        // var.
        auto new_var_symbol = ctx.dst->Sym();

        // Helper to create an AST node for the store type of the variable.
        auto store_type = [&]() {
          return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
        };

        // Track whether the new variable is a pointer or not.
        bool is_pointer = false;

        if (is_entry_point) {
          if (var->Type()->UnwrapRef()->is_handle()) {
            // For a texture or sampler variable, redeclare it as an entry point
            // parameter. Disable entry point parameter validation.
            auto* disable_validation =
                ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
            auto decos = ctx.Clone(var->Declaration()->decorations);
            decos.push_back(disable_validation);
            auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
            ctx.InsertFront(func_ast->params, param);
          } else if (sc == ast::StorageClass::kStorage ||
                     sc == ast::StorageClass::kUniform) {
            // Variables in the Storage and Uniform storage classes are
            // redeclared as entry point parameters with a pointer type, keeping
            // the original decorations and access mode.
            auto attributes = ctx.Clone(var->Declaration()->decorations);
            attributes.push_back(ctx.dst->Disable(
                ast::DisabledValidation::kEntryPointParameter));
            attributes.push_back(
                ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
            auto* param_type = ctx.dst->ty.pointer(
                store_type(), sc, var->Declaration()->declared_access);
            auto* param =
                ctx.dst->Param(new_var_symbol, param_type, attributes);
            ctx.InsertFront(func_ast->params, param);
            is_pointer = true;
          } else if (sc == ast::StorageClass::kWorkgroup &&
                     ContainsMatrix(var->Type())) {
            // Due to a bug in the MSL compiler, we use a threadgroup memory
            // argument for any workgroup allocation that contains a matrix.
            // See crbug.com/tint/938.
            // TODO(jrprice): Do this for all other workgroup variables too.

            // Create a member in the workgroup parameter struct. The member
            // reuses the original variable's name. Any struct types it uses are
            // re-emitted first so they precede the new parameter struct.
            auto member = ctx.Clone(var->Declaration()->symbol);
            workgroup_parameter_members.push_back(
                ctx.dst->Member(member, store_type()));
            CloneStructTypes(var->Type()->UnwrapRef());

            // Create a function-scope variable that is a pointer to the member.
            auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
                ctx.dst->Deref(workgroup_param()), member));
            auto* local_var =
                ctx.dst->Const(new_var_symbol,
                               ctx.dst->ty.pointer(
                                   store_type(), ast::StorageClass::kWorkgroup),
                               member_ptr);
            ctx.InsertFront(func_ast->body->statements,
                            ctx.dst->Decl(local_var));
            is_pointer = true;
          } else {
            // Variables in the Private and Workgroup storage classes are
            // redeclared at function scope, keeping their storage class and
            // initializer. Disable storage class validation on this variable.
            auto* disable_validation =
                ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
            auto* constructor = ctx.Clone(var->Declaration()->constructor);
            auto* local_var =
                ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
                             ast::DecorationList{disable_validation});
            ctx.InsertFront(func_ast->body->statements,
                            ctx.dst->Decl(local_var));
          }
        } else {
          // For a regular function, redeclare the variable as a parameter.
          // Use a pointer for non-handle types.
          auto* param_type = store_type();
          ast::DecorationList attributes;
          if (!var->Type()->UnwrapRef()->is_handle()) {
            param_type = ctx.dst->ty.pointer(
                param_type, sc, var->Declaration()->declared_access);
            is_pointer = true;

            // Disable validation of the parameter's storage class and of
            // arguments passed to it.
            attributes.push_back(
                ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
            attributes.push_back(ctx.dst->Disable(
                ast::DisabledValidation::kIgnoreInvalidPointerArgument));
          }
          ctx.InsertBack(
              func_ast->params,
              ctx.dst->Param(new_var_symbol, param_type, attributes));
        }

        // Replace all uses of the module-scope variable.
        // For non-entry points, dereference non-handle pointer parameters.
        // Only uses inside the function currently being processed are
        // rewritten here; uses in other functions are rewritten when those
        // functions are processed.
        for (auto* user : var->Users()) {
          if (user->Stmt()->Function()->Declaration() == func_ast) {
            const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
            if (is_pointer) {
              // If this identifier is used by an address-of operator, just
              // remove the address-of instead of adding a deref, since we
              // already have a pointer.
              auto* ident =
                  user->Declaration()->As<ast::IdentifierExpression>();
              if (ident_to_address_of.count(ident)) {
                ctx.Replace(ident_to_address_of[ident], expr);
                continue;
              }

              expr = ctx.dst->Deref(expr);
            }
            ctx.Replace(user->Declaration(), expr);
          }
        }

        var_to_newvar[var] = {new_var_symbol, is_pointer};
      }

      if (!workgroup_parameter_members.empty()) {
        // Create the workgroup memory parameter.
        // The parameter is a struct that contains members for each workgroup
        // variable.
        auto* str = ctx.dst->Structure(ctx.dst->Sym(),
                                       std::move(workgroup_parameter_members));
        auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
                                               ast::StorageClass::kWorkgroup);
        auto* disable_validation =
            ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
        auto* param =
            ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
        ctx.InsertFront(func_ast->params, param);
      }

      // Pass the variables as pointers to any functions that need them.
      for (auto* call : calls_to_replace[func_ast]) {
        auto* target =
            ctx.src->AST().Functions().Find(call->target.name->symbol);
        auto* target_sem = ctx.src->Sem().Get(target);

        // Add new arguments for any variables that are needed by the callee.
        // The arguments are appended in the callee's
        // TransitivelyReferencedGlobals() order, the same order in which the
        // callee's own parameters were appended with InsertBack above.
        // For entry points, pass non-handle types as pointers; in non-entry
        // functions those replacements are already pointer parameters.
        for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
          auto sc = target_var->StorageClass();
          if (sc == ast::StorageClass::kNone) {
            continue;
          }

          // NOTE(review): this relies on every global referenced by the callee
          // also being referenced (transitively) by the caller, so an entry
          // already exists; operator[] would otherwise default-insert an
          // invalid symbol.
          auto new_var = var_to_newvar[target_var];
          bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
          const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
          if (is_entry_point && !is_handle && !new_var.is_pointer) {
            // We need to pass a pointer and we don't already have one, so take
            // the address of the new variable.
            arg = ctx.dst->AddressOf(arg);
          }
          ctx.InsertBack(call->args, arg);
        }
      }
    }

    // Now remove all module-scope variables with these storage classes.
    for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
      auto* var_sem = ctx.src->Sem().Get(var_ast);
      if (var_sem->StorageClass() != ast::StorageClass::kNone) {
        ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
      }
    }
  }

 private:
  /// Struct types already emitted by CloneStructTypes().
  std::unordered_set<const sem::Struct*> cloned_structs_;
};
347
348 ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
349
350 ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
351
Run(CloneContext & ctx,const DataMap &,DataMap &)352 void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
353 const DataMap&,
354 DataMap&) {
355 State state{ctx};
356 state.Process();
357 ctx.Clone();
358 }
359
360 } // namespace transform
361 } // namespace tint
362