• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/array_length_from_uniform.h"
16 
17 #include <memory>
18 #include <string>
19 #include <utility>
20 
21 #include "src/ast/struct_block_decoration.h"
22 #include "src/program_builder.h"
23 #include "src/sem/call.h"
24 #include "src/sem/variable.h"
25 #include "src/transform/simplify_pointers.h"
26 
27 TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform);
28 TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Config);
29 TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result);
30 
31 namespace tint {
32 namespace transform {
33 
34 ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
35 ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
36 
37 /// Iterate over all arrayLength() intrinsics that operate on
38 /// storage buffer variables.
39 /// @param ctx the CloneContext.
40 /// @param functor of type void(const ast::CallExpression*, const
41 /// sem::VariableUser, const sem::GlobalVariable*). It takes in an
42 /// ast::CallExpression of the arrayLength call expression node, a
43 /// sem::VariableUser of the used storage buffer variable, and the
44 /// sem::GlobalVariable for the storage buffer.
45 template <typename F>
IterateArrayLengthOnStorageVar(CloneContext & ctx,F && functor)46 static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
47   auto& sem = ctx.src->Sem();
48 
49   // Find all calls to the arrayLength() intrinsic.
50   for (auto* node : ctx.src->ASTNodes().Objects()) {
51     auto* call_expr = node->As<ast::CallExpression>();
52     if (!call_expr) {
53       continue;
54     }
55 
56     auto* call = sem.Get(call_expr);
57     auto* intrinsic = call->Target()->As<sem::Intrinsic>();
58     if (!intrinsic || intrinsic->Type() != sem::IntrinsicType::kArrayLength) {
59       continue;
60     }
61 
62     // Get the storage buffer that contains the runtime array.
63     // We assume that the argument to `arrayLength` has the form
64     // `&resource.array`, which requires that `SimplifyPointers` have been run
65     // before this transform.
66     auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
67     if (!param || param->op != ast::UnaryOp::kAddressOf) {
68       TINT_ICE(Transform, ctx.dst->Diagnostics())
69           << "expected form of arrayLength argument to be "
70              "&resource.array";
71       break;
72     }
73     auto* accessor = param->expr->As<ast::MemberAccessorExpression>();
74     if (!accessor) {
75       TINT_ICE(Transform, ctx.dst->Diagnostics())
76           << "expected form of arrayLength argument to be "
77              "&resource.array";
78       break;
79     }
80     auto* storage_buffer_expr = accessor->structure;
81     auto* storage_buffer_sem =
82         sem.Get(storage_buffer_expr)->As<sem::VariableUser>();
83     if (!storage_buffer_sem) {
84       TINT_ICE(Transform, ctx.dst->Diagnostics())
85           << "expected form of arrayLength argument to be "
86              "&resource.array";
87       break;
88     }
89 
90     // Get the index to use for the buffer size array.
91     auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
92     if (!var) {
93       TINT_ICE(Transform, ctx.dst->Diagnostics())
94           << "storage buffer is not a global variable";
95       break;
96     }
97     functor(call_expr, storage_buffer_sem, var);
98   }
99 }
100 
Run(CloneContext & ctx,const DataMap & inputs,DataMap & outputs)101 void ArrayLengthFromUniform::Run(CloneContext& ctx,
102                                  const DataMap& inputs,
103                                  DataMap& outputs) {
104   if (!Requires<SimplifyPointers>(ctx)) {
105     return;
106   }
107 
108   auto* cfg = inputs.Get<Config>();
109   if (cfg == nullptr) {
110     ctx.dst->Diagnostics().add_error(
111         diag::System::Transform,
112         "missing transform data for " + std::string(TypeInfo().name));
113     return;
114   }
115 
116   const char* kBufferSizeMemberName = "buffer_size";
117 
118   // Determine the size of the buffer size array.
119   uint32_t max_buffer_size_index = 0;
120 
121   IterateArrayLengthOnStorageVar(
122       ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
123                const sem::GlobalVariable* var) {
124         auto binding = var->BindingPoint();
125         auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
126         if (idx_itr == cfg->bindpoint_to_size_index.end()) {
127           return;
128         }
129         if (idx_itr->second > max_buffer_size_index) {
130           max_buffer_size_index = idx_itr->second;
131         }
132       });
133 
134   // Get (or create, on first call) the uniform buffer that will receive the
135   // size of each storage buffer in the module.
136   const ast::Variable* buffer_size_ubo = nullptr;
137   auto get_ubo = [&]() {
138     if (!buffer_size_ubo) {
139       // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
140       // We do this because UBOs require an element stride that is 16-byte
141       // aligned.
142       auto* buffer_size_struct = ctx.dst->Structure(
143           ctx.dst->Sym(),
144           {ctx.dst->Member(
145               kBufferSizeMemberName,
146               ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
147                                 (max_buffer_size_index / 4) + 1))},
148 
149           ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
150       buffer_size_ubo = ctx.dst->Global(
151           ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
152           ast::StorageClass::kUniform,
153           ast::DecorationList{
154               ctx.dst->create<ast::GroupDecoration>(cfg->ubo_binding.group),
155               ctx.dst->create<ast::BindingDecoration>(
156                   cfg->ubo_binding.binding)});
157     }
158     return buffer_size_ubo;
159   };
160 
161   std::unordered_set<uint32_t> used_size_indices;
162 
163   IterateArrayLengthOnStorageVar(
164       ctx, [&](const ast::CallExpression* call_expr,
165                const sem::VariableUser* storage_buffer_sem,
166                const sem::GlobalVariable* var) {
167         auto binding = var->BindingPoint();
168         auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
169         if (idx_itr == cfg->bindpoint_to_size_index.end()) {
170           return;
171         }
172 
173         uint32_t size_index = idx_itr->second;
174         used_size_indices.insert(size_index);
175 
176         // Load the total storage buffer size from the UBO.
177         uint32_t array_index = size_index / 4;
178         auto* vec_expr = ctx.dst->IndexAccessor(
179             ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName),
180             array_index);
181         uint32_t vec_index = size_index % 4;
182         auto* total_storage_buffer_size =
183             ctx.dst->IndexAccessor(vec_expr, vec_index);
184 
185         // Calculate actual array length
186         //                total_storage_buffer_size - array_offset
187         // array_length = ----------------------------------------
188         //                             array_stride
189         auto* storage_buffer_type =
190             storage_buffer_sem->Type()->UnwrapRef()->As<sem::Struct>();
191         auto* array_member_sem = storage_buffer_type->Members().back();
192         uint32_t array_offset = array_member_sem->Offset();
193         uint32_t array_stride = array_member_sem->Size();
194         auto* array_length =
195             ctx.dst->Div(ctx.dst->Sub(total_storage_buffer_size, array_offset),
196                          array_stride);
197 
198         ctx.Replace(call_expr, array_length);
199       });
200 
201   ctx.Clone();
202 
203   outputs.Add<Result>(used_size_indices);
204 }
205 
Config(sem::BindingPoint ubo_bp)206 ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)
207     : ubo_binding(ubo_bp) {}
208 ArrayLengthFromUniform::Config::Config(const Config&) = default;
209 ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(
210     const Config&) = default;
211 ArrayLengthFromUniform::Config::~Config() = default;
212 
Result(std::unordered_set<uint32_t> used_size_indices_in)213 ArrayLengthFromUniform::Result::Result(
214     std::unordered_set<uint32_t> used_size_indices_in)
215     : used_size_indices(std::move(used_size_indices_in)) {}
216 ArrayLengthFromUniform::Result::Result(const Result&) = default;
217 ArrayLengthFromUniform::Result::~Result() = default;
218 
219 }  // namespace transform
220 }  // namespace tint
221