• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/canonicalize_entry_point_io.h"
16 
17 #include <algorithm>
18 #include <string>
19 #include <unordered_set>
20 #include <utility>
21 #include <vector>
22 
23 #include "src/ast/disable_validation_decoration.h"
24 #include "src/program_builder.h"
25 #include "src/sem/function.h"
26 #include "src/transform/unshadow.h"
27 
// Instantiate Tint's TypeInfo for the transform and for its Config input
// data. The transform's TypeInfo is read in Run() (`TypeInfo().name`) when
// reporting a missing Config; presumably Config's TypeInfo is what
// `DataMap::Get<Config>()` keys on — NOTE(review): confirm.
28 TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO);
29 TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config);
30 
31 namespace tint {
32 namespace transform {
33 
// Constructor and destructor of the transform. Both are explicitly defaulted,
// but defined out of line here rather than in the header.
34 CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default;
35 CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;
36 
37 namespace {
38 
39 // Comparison function used to reorder struct members such that all members with
40 // location attributes appear first (ordered by location slot), followed by
41 // those with builtin attributes.
StructMemberComparator(const ast::StructMember * a,const ast::StructMember * b)42 bool StructMemberComparator(const ast::StructMember* a,
43                             const ast::StructMember* b) {
44   auto* a_loc = ast::GetDecoration<ast::LocationDecoration>(a->decorations);
45   auto* b_loc = ast::GetDecoration<ast::LocationDecoration>(b->decorations);
46   auto* a_blt = ast::GetDecoration<ast::BuiltinDecoration>(a->decorations);
47   auto* b_blt = ast::GetDecoration<ast::BuiltinDecoration>(b->decorations);
48   if (a_loc) {
49     if (!b_loc) {
50       // `a` has location attribute and `b` does not: `a` goes first.
51       return true;
52     }
53     // Both have location attributes: smallest goes first.
54     return a_loc->value < b_loc->value;
55   } else {
56     if (b_loc) {
57       // `b` has location attribute and `a` does not: `b` goes first.
58       return false;
59     }
60     // Both are builtins: order doesn't matter, just use enum value.
61     return a_blt->builtin < b_blt->builtin;
62   }
63 }
64 
65 // Returns true if `deco` is a shader IO decoration.
IsShaderIODecoration(const ast::Decoration * deco)66 bool IsShaderIODecoration(const ast::Decoration* deco) {
67   return deco->IsAnyOf<ast::BuiltinDecoration, ast::InterpolateDecoration,
68                        ast::InvariantDecoration, ast::LocationDecoration>();
69 }
70 
71 // Returns true if `decos` contains a `sample_mask` builtin.
HasSampleMask(const ast::DecorationList & decos)72 bool HasSampleMask(const ast::DecorationList& decos) {
73   auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos);
74   return builtin && builtin->builtin == ast::Builtin::kSampleMask;
75 }
76 
77 }  // namespace
78 
79 /// State holds the current transform state for a single entry point.
80 struct CanonicalizeEntryPointIO::State {
81   /// OutputValue represents a shader result that the wrapper function produces.
82   struct OutputValue {
83     /// The name of the output value.
84     std::string name;
85     /// The type of the output value.
86     const ast::Type* type;
87     /// The shader IO attributes.
88     ast::DecorationList attributes;
89     /// The value itself.
90     const ast::Expression* value;
91   };
92 
93   /// The clone context.
94   CloneContext& ctx;
95   /// The transform config.
96   CanonicalizeEntryPointIO::Config const cfg;
97   /// The entry point function (AST).
98   const ast::Function* func_ast;
99   /// The entry point function (SEM).
100   const sem::Function* func_sem;
101 
102   /// The new entry point wrapper function's parameters.
103   ast::VariableList wrapper_ep_parameters;
104   /// The members of the wrapper function's struct parameter.
105   ast::StructMemberList wrapper_struct_param_members;
106   /// The name of the wrapper function's struct parameter.
107   Symbol wrapper_struct_param_name;
108   /// The parameters that will be passed to the original function.
109   ast::ExpressionList inner_call_parameters;
110   /// The members of the wrapper function's struct return type.
111   ast::StructMemberList wrapper_struct_output_members;
112   /// The wrapper function output values.
113   std::vector<OutputValue> wrapper_output_values;
114   /// The body of the wrapper function.
115   ast::StatementList wrapper_body;
116   /// Input names used by the entrypoint
117   std::unordered_set<std::string> input_names;
118 
119   /// Constructor
120   /// @param context the clone context
121   /// @param config the transform config
122   /// @param function the entry point function
Statetint::transform::CanonicalizeEntryPointIO::State123   State(CloneContext& context,
124         const CanonicalizeEntryPointIO::Config& config,
125         const ast::Function* function)
126       : ctx(context),
127         cfg(config),
128         func_ast(function),
129         func_sem(ctx.src->Sem().Get(function)) {}
130 
131   /// Clones the shader IO decorations from `src`.
132   /// @param src the decorations to clone
133   /// @return the cloned decorations
CloneShaderIOAttributestint::transform::CanonicalizeEntryPointIO::State134   ast::DecorationList CloneShaderIOAttributes(const ast::DecorationList& src) {
135     ast::DecorationList new_decorations;
136     for (auto* deco : src) {
137       if (IsShaderIODecoration(deco)) {
138         new_decorations.push_back(ctx.Clone(deco));
139       }
140     }
141     return new_decorations;
142   }
143 
144   /// Create or return a symbol for the wrapper function's struct parameter.
145   /// @returns the symbol for the struct parameter
InputStructSymboltint::transform::CanonicalizeEntryPointIO::State146   Symbol InputStructSymbol() {
147     if (!wrapper_struct_param_name.IsValid()) {
148       wrapper_struct_param_name = ctx.dst->Sym();
149     }
150     return wrapper_struct_param_name;
151   }
152 
153   /// Add a shader input to the entry point.
154   /// @param name the name of the shader input
155   /// @param type the type of the shader input
156   /// @param attributes the attributes to apply to the shader input
157   /// @returns an expression which evaluates to the value of the shader input
AddInputtint::transform::CanonicalizeEntryPointIO::State158   const ast::Expression* AddInput(std::string name,
159                                   const sem::Type* type,
160                                   ast::DecorationList attributes) {
161     auto* ast_type = CreateASTTypeFor(ctx, type);
162     if (cfg.shader_style == ShaderStyle::kSpirv) {
163       // Vulkan requires that integer user-defined fragment inputs are
164       // always decorated with `Flat`.
165       // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
166       // attribute is required for integers.
167       if (type->is_integer_scalar_or_vector() &&
168           ast::HasDecoration<ast::LocationDecoration>(attributes) &&
169           !ast::HasDecoration<ast::InterpolateDecoration>(attributes) &&
170           func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
171         attributes.push_back(ctx.dst->Interpolate(
172             ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
173       }
174 
175       // Disable validation for use of the `input` storage class.
176       attributes.push_back(
177           ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
178 
179       // Create the global variable and use its value for the shader input.
180       auto symbol = ctx.dst->Symbols().New(name);
181       const ast::Expression* value = ctx.dst->Expr(symbol);
182       if (HasSampleMask(attributes)) {
183         // Vulkan requires the type of a SampleMask builtin to be an array.
184         // Declare it as array<u32, 1> and then load the first element.
185         ast_type = ctx.dst->ty.array(ast_type, 1);
186         value = ctx.dst->IndexAccessor(value, 0);
187       }
188       ctx.dst->Global(symbol, ast_type, ast::StorageClass::kInput,
189                       std::move(attributes));
190       return value;
191     } else if (cfg.shader_style == ShaderStyle::kMsl &&
192                ast::HasDecoration<ast::BuiltinDecoration>(attributes)) {
193       // If this input is a builtin and we are targeting MSL, then add it to the
194       // parameter list and pass it directly to the inner function.
195       Symbol symbol = input_names.emplace(name).second
196                           ? ctx.dst->Symbols().Register(name)
197                           : ctx.dst->Symbols().New(name);
198       wrapper_ep_parameters.push_back(
199           ctx.dst->Param(symbol, ast_type, std::move(attributes)));
200       return ctx.dst->Expr(symbol);
201     } else {
202       // Otherwise, move it to the new structure member list.
203       Symbol symbol = input_names.emplace(name).second
204                           ? ctx.dst->Symbols().Register(name)
205                           : ctx.dst->Symbols().New(name);
206       wrapper_struct_param_members.push_back(
207           ctx.dst->Member(symbol, ast_type, std::move(attributes)));
208       return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
209     }
210   }
211 
212   /// Add a shader output to the entry point.
213   /// @param name the name of the shader output
214   /// @param type the type of the shader output
215   /// @param attributes the attributes to apply to the shader output
216   /// @param value the value of the shader output
AddOutputtint::transform::CanonicalizeEntryPointIO::State217   void AddOutput(std::string name,
218                  const sem::Type* type,
219                  ast::DecorationList attributes,
220                  const ast::Expression* value) {
221     // Vulkan requires that integer user-defined vertex outputs are
222     // always decorated with `Flat`.
223     // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
224     // attribute is required for integers.
225     if (cfg.shader_style == ShaderStyle::kSpirv &&
226         type->is_integer_scalar_or_vector() &&
227         ast::HasDecoration<ast::LocationDecoration>(attributes) &&
228         !ast::HasDecoration<ast::InterpolateDecoration>(attributes) &&
229         func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
230       attributes.push_back(ctx.dst->Interpolate(
231           ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
232     }
233 
234     OutputValue output;
235     output.name = name;
236     output.type = CreateASTTypeFor(ctx, type);
237     output.attributes = std::move(attributes);
238     output.value = value;
239     wrapper_output_values.push_back(output);
240   }
241 
242   /// Process a non-struct parameter.
243   /// This creates a new object for the shader input, moving the shader IO
244   /// attributes to it. It also adds an expression to the list of parameters
245   /// that will be passed to the original function.
246   /// @param param the original function parameter
ProcessNonStructParametertint::transform::CanonicalizeEntryPointIO::State247   void ProcessNonStructParameter(const sem::Parameter* param) {
248     // Remove the shader IO attributes from the inner function parameter, and
249     // attach them to the new object instead.
250     ast::DecorationList attributes;
251     for (auto* deco : param->Declaration()->decorations) {
252       if (IsShaderIODecoration(deco)) {
253         ctx.Remove(param->Declaration()->decorations, deco);
254         attributes.push_back(ctx.Clone(deco));
255       }
256     }
257 
258     auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol);
259     auto* input_expr = AddInput(name, param->Type(), std::move(attributes));
260     inner_call_parameters.push_back(input_expr);
261   }
262 
263   /// Process a struct parameter.
264   /// This creates new objects for each struct member, moving the shader IO
265   /// attributes to them. It also creates the structure that will be passed to
266   /// the original function.
267   /// @param param the original function parameter
ProcessStructParametertint::transform::CanonicalizeEntryPointIO::State268   void ProcessStructParameter(const sem::Parameter* param) {
269     auto* str = param->Type()->As<sem::Struct>();
270 
271     // Recreate struct members in the outer entry point and build an initializer
272     // list to pass them through to the inner function.
273     ast::ExpressionList inner_struct_values;
274     for (auto* member : str->Members()) {
275       if (member->Type()->Is<sem::Struct>()) {
276         TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
277         continue;
278       }
279 
280       auto* member_ast = member->Declaration();
281       auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
282       auto attributes = CloneShaderIOAttributes(member_ast->decorations);
283       auto* input_expr = AddInput(name, member->Type(), std::move(attributes));
284       inner_struct_values.push_back(input_expr);
285     }
286 
287     // Construct the original structure using the new shader input objects.
288     inner_call_parameters.push_back(ctx.dst->Construct(
289         ctx.Clone(param->Declaration()->type), inner_struct_values));
290   }
291 
292   /// Process the entry point return type.
293   /// This generates a list of output values that are returned by the original
294   /// function.
295   /// @param inner_ret_type the original function return type
296   /// @param original_result the result object produced by the original function
ProcessReturnTypetint::transform::CanonicalizeEntryPointIO::State297   void ProcessReturnType(const sem::Type* inner_ret_type,
298                          Symbol original_result) {
299     if (auto* str = inner_ret_type->As<sem::Struct>()) {
300       for (auto* member : str->Members()) {
301         if (member->Type()->Is<sem::Struct>()) {
302           TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
303           continue;
304         }
305 
306         auto* member_ast = member->Declaration();
307         auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
308         auto attributes = CloneShaderIOAttributes(member_ast->decorations);
309 
310         // Extract the original structure member.
311         AddOutput(name, member->Type(), std::move(attributes),
312                   ctx.dst->MemberAccessor(original_result, name));
313       }
314     } else if (!inner_ret_type->Is<sem::Void>()) {
315       auto attributes =
316           CloneShaderIOAttributes(func_ast->return_type_decorations);
317 
318       // Propagate the non-struct return value as is.
319       AddOutput("value", func_sem->ReturnType(), std::move(attributes),
320                 ctx.dst->Expr(original_result));
321     }
322   }
323 
324   /// Add a fixed sample mask to the wrapper function output.
325   /// If there is already a sample mask, bitwise-and it with the fixed mask.
326   /// Otherwise, create a new output value from the fixed mask.
AddFixedSampleMasktint::transform::CanonicalizeEntryPointIO::State327   void AddFixedSampleMask() {
328     // Check the existing output values for a sample mask builtin.
329     for (auto& outval : wrapper_output_values) {
330       if (HasSampleMask(outval.attributes)) {
331         // Combine the authored sample mask with the fixed mask.
332         outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
333         return;
334       }
335     }
336 
337     // No existing sample mask builtin was found, so create a new output value
338     // using the fixed sample mask.
339     AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(),
340               {ctx.dst->Builtin(ast::Builtin::kSampleMask)},
341               ctx.dst->Expr(cfg.fixed_sample_mask));
342   }
343 
344   /// Add a point size builtin to the wrapper function output.
AddVertexPointSizetint::transform::CanonicalizeEntryPointIO::State345   void AddVertexPointSize() {
346     // Create a new output value and assign it a literal 1.0 value.
347     AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(),
348               {ctx.dst->Builtin(ast::Builtin::kPointSize)}, ctx.dst->Expr(1.f));
349   }
350 
351   /// Create the wrapper function's struct parameter and type objects.
CreateInputStructtint::transform::CanonicalizeEntryPointIO::State352   void CreateInputStruct() {
353     // Sort the struct members to satisfy HLSL interfacing matching rules.
354     std::sort(wrapper_struct_param_members.begin(),
355               wrapper_struct_param_members.end(), StructMemberComparator);
356 
357     // Create the new struct type.
358     auto struct_name = ctx.dst->Sym();
359     auto* in_struct = ctx.dst->create<ast::Struct>(
360         struct_name, wrapper_struct_param_members, ast::DecorationList{});
361     ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
362 
363     // Create a new function parameter using this struct type.
364     auto* param =
365         ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name));
366     wrapper_ep_parameters.push_back(param);
367   }
368 
369   /// Create and return the wrapper function's struct result object.
370   /// @returns the struct type
CreateOutputStructtint::transform::CanonicalizeEntryPointIO::State371   ast::Struct* CreateOutputStruct() {
372     ast::StatementList assignments;
373 
374     auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
375 
376     // Create the struct members and their corresponding assignment statements.
377     std::unordered_set<std::string> member_names;
378     for (auto& outval : wrapper_output_values) {
379       // Use the original output name, unless that is already taken.
380       Symbol name;
381       if (member_names.count(outval.name)) {
382         name = ctx.dst->Symbols().New(outval.name);
383       } else {
384         name = ctx.dst->Symbols().Register(outval.name);
385       }
386       member_names.insert(ctx.dst->Symbols().NameFor(name));
387 
388       wrapper_struct_output_members.push_back(
389           ctx.dst->Member(name, outval.type, std::move(outval.attributes)));
390       assignments.push_back(ctx.dst->Assign(
391           ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
392     }
393 
394     // Sort the struct members to satisfy HLSL interfacing matching rules.
395     std::sort(wrapper_struct_output_members.begin(),
396               wrapper_struct_output_members.end(), StructMemberComparator);
397 
398     // Create the new struct type.
399     auto* out_struct = ctx.dst->create<ast::Struct>(
400         ctx.dst->Sym(), wrapper_struct_output_members, ast::DecorationList{});
401     ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
402 
403     // Create the output struct object, assign its members, and return it.
404     auto* result_object =
405         ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name));
406     wrapper_body.push_back(ctx.dst->Decl(result_object));
407     wrapper_body.insert(wrapper_body.end(), assignments.begin(),
408                         assignments.end());
409     wrapper_body.push_back(ctx.dst->Return(wrapper_result));
410 
411     return out_struct;
412   }
413 
414   /// Create and assign the wrapper function's output variables.
CreateSpirvOutputVariablestint::transform::CanonicalizeEntryPointIO::State415   void CreateSpirvOutputVariables() {
416     for (auto& outval : wrapper_output_values) {
417       // Disable validation for use of the `output` storage class.
418       ast::DecorationList attributes = std::move(outval.attributes);
419       attributes.push_back(
420           ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
421 
422       // Create the global variable and assign it the output value.
423       auto name = ctx.dst->Symbols().New(outval.name);
424       auto* type = outval.type;
425       const ast::Expression* lhs = ctx.dst->Expr(name);
426       if (HasSampleMask(attributes)) {
427         // Vulkan requires the type of a SampleMask builtin to be an array.
428         // Declare it as array<u32, 1> and then store to the first element.
429         type = ctx.dst->ty.array(type, 1);
430         lhs = ctx.dst->IndexAccessor(lhs, 0);
431       }
432       ctx.dst->Global(name, type, ast::StorageClass::kOutput,
433                       std::move(attributes));
434       wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
435     }
436   }
437 
438   // Recreate the original function without entry point attributes and call it.
439   /// @returns the inner function call expression
CallInnerFunctiontint::transform::CanonicalizeEntryPointIO::State440   const ast::CallExpression* CallInnerFunction() {
441     // Add a suffix to the function name, as the wrapper function will take the
442     // original entry point name.
443     auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol);
444     auto inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
445 
446     // Clone everything, dropping the function and return type attributes.
447     // The parameter attributes will have already been stripped during
448     // processing.
449     auto* inner_function = ctx.dst->create<ast::Function>(
450         inner_name, ctx.Clone(func_ast->params),
451         ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
452         ast::DecorationList{}, ast::DecorationList{});
453     ctx.Replace(func_ast, inner_function);
454 
455     // Call the function.
456     return ctx.dst->Call(inner_function->symbol, inner_call_parameters);
457   }
458 
459   /// Process the entry point function.
Processtint::transform::CanonicalizeEntryPointIO::State460   void Process() {
461     bool needs_fixed_sample_mask = false;
462     bool needs_vertex_point_size = false;
463     if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
464         cfg.fixed_sample_mask != 0xFFFFFFFF) {
465       needs_fixed_sample_mask = true;
466     }
467     if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
468         cfg.emit_vertex_point_size) {
469       needs_vertex_point_size = true;
470     }
471 
472     // Exit early if there is no shader IO to handle.
473     if (func_sem->Parameters().size() == 0 &&
474         func_sem->ReturnType()->Is<sem::Void>() && !needs_fixed_sample_mask &&
475         !needs_vertex_point_size) {
476       return;
477     }
478 
479     // Process the entry point parameters, collecting those that need to be
480     // aggregated into a single structure.
481     if (!func_sem->Parameters().empty()) {
482       for (auto* param : func_sem->Parameters()) {
483         if (param->Type()->Is<sem::Struct>()) {
484           ProcessStructParameter(param);
485         } else {
486           ProcessNonStructParameter(param);
487         }
488       }
489 
490       // Create a structure parameter for the outer entry point if necessary.
491       if (!wrapper_struct_param_members.empty()) {
492         CreateInputStruct();
493       }
494     }
495 
496     // Recreate the original function and call it.
497     auto* call_inner = CallInnerFunction();
498 
499     // Process the return type, and start building the wrapper function body.
500     std::function<const ast::Type*()> wrapper_ret_type = [&] {
501       return ctx.dst->ty.void_();
502     };
503     if (func_sem->ReturnType()->Is<sem::Void>()) {
504       // The function call is just a statement with no result.
505       wrapper_body.push_back(ctx.dst->CallStmt(call_inner));
506     } else {
507       // Capture the result of calling the original function.
508       auto* inner_result = ctx.dst->Const(
509           ctx.dst->Symbols().New("inner_result"), nullptr, call_inner);
510       wrapper_body.push_back(ctx.dst->Decl(inner_result));
511 
512       // Process the original return type to determine the outputs that the
513       // outer function needs to produce.
514       ProcessReturnType(func_sem->ReturnType(), inner_result->symbol);
515     }
516 
517     // Add a fixed sample mask, if necessary.
518     if (needs_fixed_sample_mask) {
519       AddFixedSampleMask();
520     }
521 
522     // Add the pointsize builtin, if necessary.
523     if (needs_vertex_point_size) {
524       AddVertexPointSize();
525     }
526 
527     // Produce the entry point outputs, if necessary.
528     if (!wrapper_output_values.empty()) {
529       if (cfg.shader_style == ShaderStyle::kSpirv) {
530         CreateSpirvOutputVariables();
531       } else {
532         auto* output_struct = CreateOutputStruct();
533         wrapper_ret_type = [&, output_struct] {
534           return ctx.dst->ty.type_name(output_struct->name);
535         };
536       }
537     }
538 
539     // Create the wrapper entry point function.
540     // Take the name of the original entry point function.
541     auto name = ctx.Clone(func_ast->symbol);
542     auto* wrapper_func = ctx.dst->create<ast::Function>(
543         name, wrapper_ep_parameters, wrapper_ret_type(),
544         ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->decorations),
545         ast::DecorationList{});
546     ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast,
547                     wrapper_func);
548   }
549 };
550 
Run(CloneContext & ctx,const DataMap & inputs,DataMap &)551 void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
552                                    const DataMap& inputs,
553                                    DataMap&) {
554   if (!Requires<Unshadow>(ctx)) {
555     return;
556   }
557 
558   auto* cfg = inputs.Get<Config>();
559   if (cfg == nullptr) {
560     ctx.dst->Diagnostics().add_error(
561         diag::System::Transform,
562         "missing transform data for " + std::string(TypeInfo().name));
563     return;
564   }
565 
566   // Remove entry point IO attributes from struct declarations.
567   // New structures will be created for each entry point, as necessary.
568   for (auto* ty : ctx.src->AST().TypeDecls()) {
569     if (auto* struct_ty = ty->As<ast::Struct>()) {
570       for (auto* member : struct_ty->members) {
571         for (auto* deco : member->decorations) {
572           if (IsShaderIODecoration(deco)) {
573             ctx.Remove(member->decorations, deco);
574           }
575         }
576       }
577     }
578   }
579 
580   for (auto* func_ast : ctx.src->AST().Functions()) {
581     if (!func_ast->IsEntryPoint()) {
582       continue;
583     }
584 
585     State state(ctx, *cfg, func_ast);
586     state.Process();
587   }
588 
589   ctx.Clone();
590 }
591 
Config(ShaderStyle style,uint32_t sample_mask,bool emit_point_size)592 CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
593                                          uint32_t sample_mask,
594                                          bool emit_point_size)
595     : shader_style(style),
596       fixed_sample_mask(sample_mask),
597       emit_vertex_point_size(emit_point_size) {}
598 
// Copy constructor and destructor of Config. Both are explicitly defaulted,
// but defined out of line here. State stores a copy of the Config it is given.
599 CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
600 CanonicalizeEntryPointIO::Config::~Config() = default;
601 
602 }  // namespace transform
603 }  // namespace tint
604