1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/num_workgroups_from_uniform.h"
16
17 #include <memory>
18 #include <string>
19 #include <unordered_set>
20 #include <utility>
21
22 #include "src/program_builder.h"
23 #include "src/sem/function.h"
24 #include "src/transform/canonicalize_entry_point_io.h"
25 #include "src/utils/hash.h"
26
27 TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform);
28 TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config);
29
30 namespace tint {
31 namespace transform {
32 namespace {
33 /// Accessor describes the identifiers used in a member accessor that is being
34 /// used to retrieve the num_workgroups builtin from a parameter.
35 struct Accessor {
36 Symbol param;
37 Symbol member;
38
39 /// Equality operator
operator ==tint::transform::__anon3e2de3200111::Accessor40 bool operator==(const Accessor& other) const {
41 return param == other.param && member == other.member;
42 }
43 /// Hash function
44 struct Hasher {
operator ()tint::transform::__anon3e2de3200111::Accessor::Hasher45 size_t operator()(const Accessor& a) const {
46 return utils::Hash(a.param, a.member);
47 }
48 };
49 };
50 } // namespace
51
/// Constructor. Defined out-of-line with the compiler-generated (defaulted)
/// behavior.
52 NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
/// Destructor. Defined out-of-line with the compiler-generated (defaulted)
/// behavior.
53 NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
54
Run(CloneContext & ctx,const DataMap & inputs,DataMap &)55 void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
56 const DataMap& inputs,
57 DataMap&) {
58 if (!Requires<CanonicalizeEntryPointIO>(ctx)) {
59 return;
60 }
61
62 auto* cfg = inputs.Get<Config>();
63 if (cfg == nullptr) {
64 ctx.dst->Diagnostics().add_error(
65 diag::System::Transform,
66 "missing transform data for " + std::string(TypeInfo().name));
67 return;
68 }
69
70 const char* kNumWorkgroupsMemberName = "num_workgroups";
71
72 // Find all entry point parameters that declare the num_workgroups builtin.
73 std::unordered_set<Accessor, Accessor::Hasher> to_replace;
74 for (auto* func : ctx.src->AST().Functions()) {
75 // num_workgroups is only valid for compute stages.
76 if (func->PipelineStage() != ast::PipelineStage::kCompute) {
77 continue;
78 }
79
80 for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
81 // Because the CanonicalizeEntryPointIO transform has been run, builtins
82 // will only appear as struct members.
83 auto* str = param->Type()->As<sem::Struct>();
84 if (!str) {
85 continue;
86 }
87
88 for (auto* member : str->Members()) {
89 auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
90 member->Declaration()->decorations);
91 if (!builtin || builtin->builtin != ast::Builtin::kNumWorkgroups) {
92 continue;
93 }
94
95 // Capture the symbols that would be used to access this member, which
96 // we will replace later. We currently have no way to get from the
97 // parameter directly to the member accessor expressions that use it.
98 to_replace.insert(
99 {param->Declaration()->symbol, member->Declaration()->symbol});
100
101 // Remove the struct member.
102 // The CanonicalizeEntryPointIO transform will have generated this
103 // struct uniquely for this particular entry point, so we know that
104 // there will be no other uses of this struct in the module and that we
105 // can safely modify it here.
106 ctx.Remove(str->Declaration()->members, member->Declaration());
107
108 // If this is the only member, remove the struct and parameter too.
109 if (str->Members().size() == 1) {
110 ctx.Remove(func->params, param->Declaration());
111 ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
112 }
113 }
114 }
115 }
116
117 // Get (or create, on first call) the uniform buffer that will receive the
118 // number of workgroups.
119 const ast::Variable* num_workgroups_ubo = nullptr;
120 auto get_ubo = [&]() {
121 if (!num_workgroups_ubo) {
122 auto* num_workgroups_struct = ctx.dst->Structure(
123 ctx.dst->Sym(),
124 {ctx.dst->Member(kNumWorkgroupsMemberName,
125 ctx.dst->ty.vec3(ctx.dst->ty.u32()))},
126 ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
127 num_workgroups_ubo = ctx.dst->Global(
128 ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct),
129 ast::StorageClass::kUniform,
130 ast::DecorationList{ctx.dst->GroupAndBinding(
131 cfg->ubo_binding.group, cfg->ubo_binding.binding)});
132 }
133 return num_workgroups_ubo;
134 };
135
136 // Now replace all the places where the builtins are accessed with the value
137 // loaded from the uniform buffer.
138 for (auto* node : ctx.src->ASTNodes().Objects()) {
139 auto* accessor = node->As<ast::MemberAccessorExpression>();
140 if (!accessor) {
141 continue;
142 }
143 auto* ident = accessor->structure->As<ast::IdentifierExpression>();
144 if (!ident) {
145 continue;
146 }
147
148 if (to_replace.count({ident->symbol, accessor->member->symbol})) {
149 ctx.Replace(accessor, ctx.dst->MemberAccessor(get_ubo()->symbol,
150 kNumWorkgroupsMemberName));
151 }
152 }
153
154 ctx.Clone();
155 }
156
/// Constructor
/// @param ubo_bp the binding point stored in `ubo_binding`; Run() passes its
/// group and binding to the uniform buffer that it creates
Config(sem::BindingPoint ubo_bp)157 NumWorkgroupsFromUniform::Config::Config(sem::BindingPoint ubo_bp)
158 : ubo_binding(ubo_bp) {}
/// Copy constructor (defaulted)
159 NumWorkgroupsFromUniform::Config::Config(const Config&) = default;
/// Destructor (defaulted)
160 NumWorkgroupsFromUniform::Config::~Config() = default;
161
162 } // namespace transform
163 } // namespace tint
164