1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/calculate_array_length.h"
16
17 #include <unordered_map>
18 #include <utility>
19
20 #include "src/ast/call_statement.h"
21 #include "src/ast/disable_validation_decoration.h"
22 #include "src/program_builder.h"
23 #include "src/sem/block_statement.h"
24 #include "src/sem/call.h"
25 #include "src/sem/statement.h"
26 #include "src/sem/struct.h"
27 #include "src/sem/variable.h"
28 #include "src/transform/simplify_pointers.h"
29 #include "src/utils/hash.h"
30 #include "src/utils/map.h"
31
32 TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength);
33 TINT_INSTANTIATE_TYPEINFO(
34 tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
35
36 namespace tint {
37 namespace transform {
38
39 namespace {
40
41 /// ArrayUsage describes a runtime array usage.
42 /// It is used as a key by the array_length_by_usage map.
43 struct ArrayUsage {
44 ast::BlockStatement const* const block;
45 sem::Node const* const buffer;
operator ==tint::transform::__anond045fda90111::ArrayUsage46 bool operator==(const ArrayUsage& rhs) const {
47 return block == rhs.block && buffer == rhs.buffer;
48 }
49 struct Hasher {
operator ()tint::transform::__anond045fda90111::ArrayUsage::Hasher50 inline std::size_t operator()(const ArrayUsage& u) const {
51 return utils::Hash(u.block, u.buffer);
52 }
53 };
54 };
55
56 } // namespace
57
BufferSizeIntrinsic(ProgramID pid)58 CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid)
59 : Base(pid) {}
60 CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
InternalName() const61 std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
62 return "intrinsic_buffer_size";
63 }
64
65 const CalculateArrayLength::BufferSizeIntrinsic*
Clone(CloneContext * ctx) const66 CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
67 return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(
68 ctx->dst->ID());
69 }
70
71 CalculateArrayLength::CalculateArrayLength() = default;
72 CalculateArrayLength::~CalculateArrayLength() = default;
73
Run(CloneContext & ctx,const DataMap &,DataMap &)74 void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) {
75 auto& sem = ctx.src->Sem();
76 if (!Requires<SimplifyPointers>(ctx)) {
77 return;
78 }
79
80 // get_buffer_size_intrinsic() emits the function decorated with
81 // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
82 // [RW]ByteAddressBuffer.GetDimensions().
83 std::unordered_map<const sem::Struct*, Symbol> buffer_size_intrinsics;
84 auto get_buffer_size_intrinsic = [&](const sem::Struct* buffer_type) {
85 return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
86 auto name = ctx.dst->Sym();
87 auto* buffer_typename =
88 ctx.dst->ty.type_name(ctx.Clone(buffer_type->Declaration()->name));
89 auto* disable_validation = ctx.dst->Disable(
90 ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
91 auto* func = ctx.dst->create<ast::Function>(
92 name,
93 ast::VariableList{
94 // Note: The buffer parameter requires the kStorage StorageClass
95 // in order for HLSL to emit this as a ByteAddressBuffer.
96 ctx.dst->create<ast::Variable>(
97 ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
98 ast::Access::kUndefined, buffer_typename, true, nullptr,
99 ast::DecorationList{disable_validation}),
100 ctx.dst->Param("result",
101 ctx.dst->ty.pointer(ctx.dst->ty.u32(),
102 ast::StorageClass::kFunction)),
103 },
104 ctx.dst->ty.void_(), nullptr,
105 ast::DecorationList{
106 ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
107 },
108 ast::DecorationList{});
109 ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(),
110 buffer_type->Declaration(), func);
111 return name;
112 });
113 };
114
115 std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher>
116 array_length_by_usage;
117
118 // Find all the arrayLength() calls...
119 for (auto* node : ctx.src->ASTNodes().Objects()) {
120 if (auto* call_expr = node->As<ast::CallExpression>()) {
121 auto* call = sem.Get(call_expr);
122 if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) {
123 if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
124 // We're dealing with an arrayLength() call
125
126 // https://gpuweb.github.io/gpuweb/wgsl/#array-types states:
127 //
128 // * The last member of the structure type defining the store type for
129 // a variable in the storage storage class may be a runtime-sized
130 // array.
131 // * A runtime-sized array must not be used as the store type or
132 // contained within a store type in any other cases.
133 // * An expression must not evaluate to a runtime-sized array type.
134 //
135 // We can assume that the arrayLength() call has a single argument of
136 // the form: arrayLength(&X.Y) where X is an expression that resolves
137 // to the storage buffer structure, and Y is the runtime sized array.
138 auto* arg = call_expr->args[0];
139 auto* address_of = arg->As<ast::UnaryOpExpression>();
140 if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
141 TINT_ICE(Transform, ctx.dst->Diagnostics())
142 << "arrayLength() expected pointer to member access, got "
143 << address_of->TypeInfo().name;
144 }
145 auto* array_expr = address_of->expr;
146
147 auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
148 if (!accessor) {
149 TINT_ICE(Transform, ctx.dst->Diagnostics())
150 << "arrayLength() expected pointer to member access, got "
151 "pointer to "
152 << array_expr->TypeInfo().name;
153 break;
154 }
155 auto* storage_buffer_expr = accessor->structure;
156 auto* storage_buffer_sem = sem.Get(storage_buffer_expr);
157 auto* storage_buffer_type =
158 storage_buffer_sem->Type()->UnwrapRef()->As<sem::Struct>();
159
160 // Generate BufferSizeIntrinsic for this storage type if we haven't
161 // already
162 auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
163
164 if (!storage_buffer_type) {
165 TINT_ICE(Transform, ctx.dst->Diagnostics())
166 << "arrayLength(X.Y) expected X to be sem::Struct, got "
167 << storage_buffer_type->FriendlyName(ctx.src->Symbols());
168 break;
169 }
170
171 // Find the current statement block
172 auto* block = call->Stmt()->Block()->Declaration();
173
174 // If the storage_buffer_expr is resolves to a variable (typically
175 // true) then key the array_length from the variable. If not, key off
176 // the expression semantic node, which will be unique per call to
177 // arrayLength().
178 const sem::Node* storage_buffer_usage = storage_buffer_sem;
179 if (auto* user = storage_buffer_sem->As<sem::VariableUser>()) {
180 storage_buffer_usage = user->Variable();
181 }
182
183 auto array_length = utils::GetOrCreate(
184 array_length_by_usage, {block, storage_buffer_usage}, [&] {
185 // First time this array length is used for this block.
186 // Let's calculate it.
187
188 // Semantic info for the runtime array structure member
189 auto* array_member_sem = storage_buffer_type->Members().back();
190
191 // Construct the variable that'll hold the result of
192 // RWByteAddressBuffer.GetDimensions()
193 auto* buffer_size_result = ctx.dst->Decl(
194 ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(),
195 ast::StorageClass::kNone, ctx.dst->Expr(0u)));
196
197 // Call storage_buffer.GetDimensions(&buffer_size_result)
198 auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
199 // BufferSizeIntrinsic(X, ARGS...) is
200 // translated to:
201 // X.GetDimensions(ARGS..) by the writer
202 buffer_size, ctx.Clone(storage_buffer_expr),
203 ctx.dst->AddressOf(
204 ctx.dst->Expr(buffer_size_result->variable->symbol))));
205
206 // Calculate actual array length
207 // total_storage_buffer_size - array_offset
208 // array_length = ----------------------------------------
209 // array_stride
210 auto name = ctx.dst->Sym();
211 uint32_t array_offset = array_member_sem->Offset();
212 uint32_t array_stride = array_member_sem->Size();
213 auto* array_length_var = ctx.dst->Decl(ctx.dst->Const(
214 name, ctx.dst->ty.u32(),
215 ctx.dst->Div(
216 ctx.dst->Sub(buffer_size_result->variable->symbol,
217 array_offset),
218 array_stride)));
219
220 // Insert the array length calculations at the top of the block
221 ctx.InsertBefore(block->statements, block->statements[0],
222 buffer_size_result);
223 ctx.InsertBefore(block->statements, block->statements[0],
224 call_get_dims);
225 ctx.InsertBefore(block->statements, block->statements[0],
226 array_length_var);
227 return name;
228 });
229
230 // Replace the call to arrayLength() with the array length variable
231 ctx.Replace(call_expr, ctx.dst->Expr(array_length));
232 }
233 }
234 }
235 }
236
237 ctx.Clone();
238 }
239
240 } // namespace transform
241 } // namespace tint
242