1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/array_length_from_uniform.h"
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20
21 #include "src/ast/struct_block_decoration.h"
22 #include "src/program_builder.h"
23 #include "src/sem/call.h"
24 #include "src/sem/variable.h"
25 #include "src/transform/simplify_pointers.h"
26
// Instantiate the Tint TypeInfo for the transform and for its Config / Result
// data classes (Run() relies on TypeInfo(), e.g. TypeInfo().name in its
// "missing transform data" diagnostic).
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Config);
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result);
30
31 namespace tint {
32 namespace transform {
33
/// Constructor. Per-run settings are not passed here; Run() reads them from
/// the ArrayLengthFromUniform::Config entry of its input DataMap.
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
/// Destructor.
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
36
37 /// Iterate over all arrayLength() intrinsics that operate on
38 /// storage buffer variables.
39 /// @param ctx the CloneContext.
40 /// @param functor of type void(const ast::CallExpression*, const
41 /// sem::VariableUser, const sem::GlobalVariable*). It takes in an
42 /// ast::CallExpression of the arrayLength call expression node, a
43 /// sem::VariableUser of the used storage buffer variable, and the
44 /// sem::GlobalVariable for the storage buffer.
45 template <typename F>
IterateArrayLengthOnStorageVar(CloneContext & ctx,F && functor)46 static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
47 auto& sem = ctx.src->Sem();
48
49 // Find all calls to the arrayLength() intrinsic.
50 for (auto* node : ctx.src->ASTNodes().Objects()) {
51 auto* call_expr = node->As<ast::CallExpression>();
52 if (!call_expr) {
53 continue;
54 }
55
56 auto* call = sem.Get(call_expr);
57 auto* intrinsic = call->Target()->As<sem::Intrinsic>();
58 if (!intrinsic || intrinsic->Type() != sem::IntrinsicType::kArrayLength) {
59 continue;
60 }
61
62 // Get the storage buffer that contains the runtime array.
63 // We assume that the argument to `arrayLength` has the form
64 // `&resource.array`, which requires that `SimplifyPointers` have been run
65 // before this transform.
66 auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
67 if (!param || param->op != ast::UnaryOp::kAddressOf) {
68 TINT_ICE(Transform, ctx.dst->Diagnostics())
69 << "expected form of arrayLength argument to be "
70 "&resource.array";
71 break;
72 }
73 auto* accessor = param->expr->As<ast::MemberAccessorExpression>();
74 if (!accessor) {
75 TINT_ICE(Transform, ctx.dst->Diagnostics())
76 << "expected form of arrayLength argument to be "
77 "&resource.array";
78 break;
79 }
80 auto* storage_buffer_expr = accessor->structure;
81 auto* storage_buffer_sem =
82 sem.Get(storage_buffer_expr)->As<sem::VariableUser>();
83 if (!storage_buffer_sem) {
84 TINT_ICE(Transform, ctx.dst->Diagnostics())
85 << "expected form of arrayLength argument to be "
86 "&resource.array";
87 break;
88 }
89
90 // Get the index to use for the buffer size array.
91 auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
92 if (!var) {
93 TINT_ICE(Transform, ctx.dst->Diagnostics())
94 << "storage buffer is not a global variable";
95 break;
96 }
97 functor(call_expr, storage_buffer_sem, var);
98 }
99 }
100
Run(CloneContext & ctx,const DataMap & inputs,DataMap & outputs)101 void ArrayLengthFromUniform::Run(CloneContext& ctx,
102 const DataMap& inputs,
103 DataMap& outputs) {
104 if (!Requires<SimplifyPointers>(ctx)) {
105 return;
106 }
107
108 auto* cfg = inputs.Get<Config>();
109 if (cfg == nullptr) {
110 ctx.dst->Diagnostics().add_error(
111 diag::System::Transform,
112 "missing transform data for " + std::string(TypeInfo().name));
113 return;
114 }
115
116 const char* kBufferSizeMemberName = "buffer_size";
117
118 // Determine the size of the buffer size array.
119 uint32_t max_buffer_size_index = 0;
120
121 IterateArrayLengthOnStorageVar(
122 ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
123 const sem::GlobalVariable* var) {
124 auto binding = var->BindingPoint();
125 auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
126 if (idx_itr == cfg->bindpoint_to_size_index.end()) {
127 return;
128 }
129 if (idx_itr->second > max_buffer_size_index) {
130 max_buffer_size_index = idx_itr->second;
131 }
132 });
133
134 // Get (or create, on first call) the uniform buffer that will receive the
135 // size of each storage buffer in the module.
136 const ast::Variable* buffer_size_ubo = nullptr;
137 auto get_ubo = [&]() {
138 if (!buffer_size_ubo) {
139 // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
140 // We do this because UBOs require an element stride that is 16-byte
141 // aligned.
142 auto* buffer_size_struct = ctx.dst->Structure(
143 ctx.dst->Sym(),
144 {ctx.dst->Member(
145 kBufferSizeMemberName,
146 ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
147 (max_buffer_size_index / 4) + 1))},
148
149 ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
150 buffer_size_ubo = ctx.dst->Global(
151 ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
152 ast::StorageClass::kUniform,
153 ast::DecorationList{
154 ctx.dst->create<ast::GroupDecoration>(cfg->ubo_binding.group),
155 ctx.dst->create<ast::BindingDecoration>(
156 cfg->ubo_binding.binding)});
157 }
158 return buffer_size_ubo;
159 };
160
161 std::unordered_set<uint32_t> used_size_indices;
162
163 IterateArrayLengthOnStorageVar(
164 ctx, [&](const ast::CallExpression* call_expr,
165 const sem::VariableUser* storage_buffer_sem,
166 const sem::GlobalVariable* var) {
167 auto binding = var->BindingPoint();
168 auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
169 if (idx_itr == cfg->bindpoint_to_size_index.end()) {
170 return;
171 }
172
173 uint32_t size_index = idx_itr->second;
174 used_size_indices.insert(size_index);
175
176 // Load the total storage buffer size from the UBO.
177 uint32_t array_index = size_index / 4;
178 auto* vec_expr = ctx.dst->IndexAccessor(
179 ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName),
180 array_index);
181 uint32_t vec_index = size_index % 4;
182 auto* total_storage_buffer_size =
183 ctx.dst->IndexAccessor(vec_expr, vec_index);
184
185 // Calculate actual array length
186 // total_storage_buffer_size - array_offset
187 // array_length = ----------------------------------------
188 // array_stride
189 auto* storage_buffer_type =
190 storage_buffer_sem->Type()->UnwrapRef()->As<sem::Struct>();
191 auto* array_member_sem = storage_buffer_type->Members().back();
192 uint32_t array_offset = array_member_sem->Offset();
193 uint32_t array_stride = array_member_sem->Size();
194 auto* array_length =
195 ctx.dst->Div(ctx.dst->Sub(total_storage_buffer_size, array_offset),
196 array_stride);
197
198 ctx.Replace(call_expr, array_length);
199 });
200
201 ctx.Clone();
202
203 outputs.Add<Result>(used_size_indices);
204 }
205
/// Constructor.
/// @param ubo_bp the binding point (group and binding) that Run() assigns to
/// the generated uniform buffer holding the storage buffer sizes.
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)
    : ubo_binding(ubo_bp) {}
/// Copy constructor.
ArrayLengthFromUniform::Config::Config(const Config&) = default;
/// Copy assignment operator.
/// @returns this Config
ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(
    const Config&) = default;
/// Destructor.
ArrayLengthFromUniform::Config::~Config() = default;
212
/// Constructor.
/// @param used_size_indices_in the buffer-size indices that Run() actually
/// referenced when replacing arrayLength() calls; taken by value and moved in.
ArrayLengthFromUniform::Result::Result(
    std::unordered_set<uint32_t> used_size_indices_in)
    : used_size_indices(std::move(used_size_indices_in)) {}
/// Copy constructor.
ArrayLengthFromUniform::Result::Result(const Result&) = default;
/// Destructor.
ArrayLengthFromUniform::Result::~Result() = default;
218
219 } // namespace transform
220 } // namespace tint
221