• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/num_workgroups_from_uniform.h"
16 
17 #include <memory>
18 #include <string>
19 #include <unordered_set>
20 #include <utility>
21 
22 #include "src/program_builder.h"
23 #include "src/sem/function.h"
24 #include "src/transform/canonicalize_entry_point_io.h"
25 #include "src/utils/hash.h"
26 
27 TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform);
28 TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config);
29 
30 namespace tint {
31 namespace transform {
32 namespace {
33 /// Accessor describes the identifiers used in a member accessor that is being
34 /// used to retrieve the num_workgroups builtin from a parameter.
35 struct Accessor {
36   Symbol param;
37   Symbol member;
38 
39   /// Equality operator
operator ==tint::transform::__anon3e2de3200111::Accessor40   bool operator==(const Accessor& other) const {
41     return param == other.param && member == other.member;
42   }
43   /// Hash function
44   struct Hasher {
operator ()tint::transform::__anon3e2de3200111::Accessor::Hasher45     size_t operator()(const Accessor& a) const {
46       return utils::Hash(a.param, a.member);
47     }
48   };
49 };
50 }  // namespace
51 
52 NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
53 NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
54 
Run(CloneContext & ctx,const DataMap & inputs,DataMap &)55 void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
56                                    const DataMap& inputs,
57                                    DataMap&) {
58   if (!Requires<CanonicalizeEntryPointIO>(ctx)) {
59     return;
60   }
61 
62   auto* cfg = inputs.Get<Config>();
63   if (cfg == nullptr) {
64     ctx.dst->Diagnostics().add_error(
65         diag::System::Transform,
66         "missing transform data for " + std::string(TypeInfo().name));
67     return;
68   }
69 
70   const char* kNumWorkgroupsMemberName = "num_workgroups";
71 
72   // Find all entry point parameters that declare the num_workgroups builtin.
73   std::unordered_set<Accessor, Accessor::Hasher> to_replace;
74   for (auto* func : ctx.src->AST().Functions()) {
75     // num_workgroups is only valid for compute stages.
76     if (func->PipelineStage() != ast::PipelineStage::kCompute) {
77       continue;
78     }
79 
80     for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
81       // Because the CanonicalizeEntryPointIO transform has been run, builtins
82       // will only appear as struct members.
83       auto* str = param->Type()->As<sem::Struct>();
84       if (!str) {
85         continue;
86       }
87 
88       for (auto* member : str->Members()) {
89         auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
90             member->Declaration()->decorations);
91         if (!builtin || builtin->builtin != ast::Builtin::kNumWorkgroups) {
92           continue;
93         }
94 
95         // Capture the symbols that would be used to access this member, which
96         // we will replace later. We currently have no way to get from the
97         // parameter directly to the member accessor expressions that use it.
98         to_replace.insert(
99             {param->Declaration()->symbol, member->Declaration()->symbol});
100 
101         // Remove the struct member.
102         // The CanonicalizeEntryPointIO transform will have generated this
103         // struct uniquely for this particular entry point, so we know that
104         // there will be no other uses of this struct in the module and that we
105         // can safely modify it here.
106         ctx.Remove(str->Declaration()->members, member->Declaration());
107 
108         // If this is the only member, remove the struct and parameter too.
109         if (str->Members().size() == 1) {
110           ctx.Remove(func->params, param->Declaration());
111           ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
112         }
113       }
114     }
115   }
116 
117   // Get (or create, on first call) the uniform buffer that will receive the
118   // number of workgroups.
119   const ast::Variable* num_workgroups_ubo = nullptr;
120   auto get_ubo = [&]() {
121     if (!num_workgroups_ubo) {
122       auto* num_workgroups_struct = ctx.dst->Structure(
123           ctx.dst->Sym(),
124           {ctx.dst->Member(kNumWorkgroupsMemberName,
125                            ctx.dst->ty.vec3(ctx.dst->ty.u32()))},
126           ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
127       num_workgroups_ubo = ctx.dst->Global(
128           ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct),
129           ast::StorageClass::kUniform,
130           ast::DecorationList{ctx.dst->GroupAndBinding(
131               cfg->ubo_binding.group, cfg->ubo_binding.binding)});
132     }
133     return num_workgroups_ubo;
134   };
135 
136   // Now replace all the places where the builtins are accessed with the value
137   // loaded from the uniform buffer.
138   for (auto* node : ctx.src->ASTNodes().Objects()) {
139     auto* accessor = node->As<ast::MemberAccessorExpression>();
140     if (!accessor) {
141       continue;
142     }
143     auto* ident = accessor->structure->As<ast::IdentifierExpression>();
144     if (!ident) {
145       continue;
146     }
147 
148     if (to_replace.count({ident->symbol, accessor->member->symbol})) {
149       ctx.Replace(accessor, ctx.dst->MemberAccessor(get_ubo()->symbol,
150                                                     kNumWorkgroupsMemberName));
151     }
152   }
153 
154   ctx.Clone();
155 }
156 
Config(sem::BindingPoint ubo_bp)157 NumWorkgroupsFromUniform::Config::Config(sem::BindingPoint ubo_bp)
158     : ubo_binding(ubo_bp) {}
159 NumWorkgroupsFromUniform::Config::Config(const Config&) = default;
160 NumWorkgroupsFromUniform::Config::~Config() = default;
161 
162 }  // namespace transform
163 }  // namespace tint
164