• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/calculate_array_length.h"
16 
17 #include <unordered_map>
18 #include <utility>
19 
20 #include "src/ast/call_statement.h"
21 #include "src/ast/disable_validation_decoration.h"
22 #include "src/program_builder.h"
23 #include "src/sem/block_statement.h"
24 #include "src/sem/call.h"
25 #include "src/sem/statement.h"
26 #include "src/sem/struct.h"
27 #include "src/sem/variable.h"
28 #include "src/transform/simplify_pointers.h"
29 #include "src/utils/hash.h"
30 #include "src/utils/map.h"
31 
32 TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength);
33 TINT_INSTANTIATE_TYPEINFO(
34     tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
35 
36 namespace tint {
37 namespace transform {
38 
39 namespace {
40 
41 /// ArrayUsage describes a runtime array usage.
42 /// It is used as a key by the array_length_by_usage map.
43 struct ArrayUsage {
44   ast::BlockStatement const* const block;
45   sem::Node const* const buffer;
operator ==tint::transform::__anond045fda90111::ArrayUsage46   bool operator==(const ArrayUsage& rhs) const {
47     return block == rhs.block && buffer == rhs.buffer;
48   }
49   struct Hasher {
operator ()tint::transform::__anond045fda90111::ArrayUsage::Hasher50     inline std::size_t operator()(const ArrayUsage& u) const {
51       return utils::Hash(u.block, u.buffer);
52     }
53   };
54 };
55 
56 }  // namespace
57 
BufferSizeIntrinsic(ProgramID pid)58 CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid)
59     : Base(pid) {}
60 CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
InternalName() const61 std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
62   return "intrinsic_buffer_size";
63 }
64 
65 const CalculateArrayLength::BufferSizeIntrinsic*
Clone(CloneContext * ctx) const66 CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
67   return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(
68       ctx->dst->ID());
69 }
70 
71 CalculateArrayLength::CalculateArrayLength() = default;
72 CalculateArrayLength::~CalculateArrayLength() = default;
73 
Run(CloneContext & ctx,const DataMap &,DataMap &)74 void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) {
75   auto& sem = ctx.src->Sem();
76   if (!Requires<SimplifyPointers>(ctx)) {
77     return;
78   }
79 
80   // get_buffer_size_intrinsic() emits the function decorated with
81   // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
82   // [RW]ByteAddressBuffer.GetDimensions().
83   std::unordered_map<const sem::Struct*, Symbol> buffer_size_intrinsics;
84   auto get_buffer_size_intrinsic = [&](const sem::Struct* buffer_type) {
85     return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
86       auto name = ctx.dst->Sym();
87       auto* buffer_typename =
88           ctx.dst->ty.type_name(ctx.Clone(buffer_type->Declaration()->name));
89       auto* disable_validation = ctx.dst->Disable(
90           ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
91       auto* func = ctx.dst->create<ast::Function>(
92           name,
93           ast::VariableList{
94               // Note: The buffer parameter requires the kStorage StorageClass
95               // in order for HLSL to emit this as a ByteAddressBuffer.
96               ctx.dst->create<ast::Variable>(
97                   ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
98                   ast::Access::kUndefined, buffer_typename, true, nullptr,
99                   ast::DecorationList{disable_validation}),
100               ctx.dst->Param("result",
101                              ctx.dst->ty.pointer(ctx.dst->ty.u32(),
102                                                  ast::StorageClass::kFunction)),
103           },
104           ctx.dst->ty.void_(), nullptr,
105           ast::DecorationList{
106               ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
107           },
108           ast::DecorationList{});
109       ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(),
110                       buffer_type->Declaration(), func);
111       return name;
112     });
113   };
114 
115   std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher>
116       array_length_by_usage;
117 
118   // Find all the arrayLength() calls...
119   for (auto* node : ctx.src->ASTNodes().Objects()) {
120     if (auto* call_expr = node->As<ast::CallExpression>()) {
121       auto* call = sem.Get(call_expr);
122       if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) {
123         if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
124           // We're dealing with an arrayLength() call
125 
126           // https://gpuweb.github.io/gpuweb/wgsl/#array-types states:
127           //
128           // * The last member of the structure type defining the store type for
129           //   a variable in the storage storage class may be a runtime-sized
130           //   array.
131           // * A runtime-sized array must not be used as the store type or
132           //   contained within a store type in any other cases.
133           // * An expression must not evaluate to a runtime-sized array type.
134           //
135           // We can assume that the arrayLength() call has a single argument of
136           // the form: arrayLength(&X.Y) where X is an expression that resolves
137           // to the storage buffer structure, and Y is the runtime sized array.
138           auto* arg = call_expr->args[0];
139           auto* address_of = arg->As<ast::UnaryOpExpression>();
140           if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
141             TINT_ICE(Transform, ctx.dst->Diagnostics())
142                 << "arrayLength() expected pointer to member access, got "
143                 << address_of->TypeInfo().name;
144           }
145           auto* array_expr = address_of->expr;
146 
147           auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
148           if (!accessor) {
149             TINT_ICE(Transform, ctx.dst->Diagnostics())
150                 << "arrayLength() expected pointer to member access, got "
151                    "pointer to "
152                 << array_expr->TypeInfo().name;
153             break;
154           }
155           auto* storage_buffer_expr = accessor->structure;
156           auto* storage_buffer_sem = sem.Get(storage_buffer_expr);
157           auto* storage_buffer_type =
158               storage_buffer_sem->Type()->UnwrapRef()->As<sem::Struct>();
159 
160           // Generate BufferSizeIntrinsic for this storage type if we haven't
161           // already
162           auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
163 
164           if (!storage_buffer_type) {
165             TINT_ICE(Transform, ctx.dst->Diagnostics())
166                 << "arrayLength(X.Y) expected X to be sem::Struct, got "
167                 << storage_buffer_type->FriendlyName(ctx.src->Symbols());
168             break;
169           }
170 
171           // Find the current statement block
172           auto* block = call->Stmt()->Block()->Declaration();
173 
174           // If the storage_buffer_expr is resolves to a variable (typically
175           // true) then key the array_length from the variable. If not, key off
176           // the expression semantic node, which will be unique per call to
177           // arrayLength().
178           const sem::Node* storage_buffer_usage = storage_buffer_sem;
179           if (auto* user = storage_buffer_sem->As<sem::VariableUser>()) {
180             storage_buffer_usage = user->Variable();
181           }
182 
183           auto array_length = utils::GetOrCreate(
184               array_length_by_usage, {block, storage_buffer_usage}, [&] {
185                 // First time this array length is used for this block.
186                 // Let's calculate it.
187 
188                 // Semantic info for the runtime array structure member
189                 auto* array_member_sem = storage_buffer_type->Members().back();
190 
191                 // Construct the variable that'll hold the result of
192                 // RWByteAddressBuffer.GetDimensions()
193                 auto* buffer_size_result = ctx.dst->Decl(
194                     ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(),
195                                  ast::StorageClass::kNone, ctx.dst->Expr(0u)));
196 
197                 // Call storage_buffer.GetDimensions(&buffer_size_result)
198                 auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
199                     // BufferSizeIntrinsic(X, ARGS...) is
200                     // translated to:
201                     //  X.GetDimensions(ARGS..) by the writer
202                     buffer_size, ctx.Clone(storage_buffer_expr),
203                     ctx.dst->AddressOf(
204                         ctx.dst->Expr(buffer_size_result->variable->symbol))));
205 
206                 // Calculate actual array length
207                 //                total_storage_buffer_size - array_offset
208                 // array_length = ----------------------------------------
209                 //                             array_stride
210                 auto name = ctx.dst->Sym();
211                 uint32_t array_offset = array_member_sem->Offset();
212                 uint32_t array_stride = array_member_sem->Size();
213                 auto* array_length_var = ctx.dst->Decl(ctx.dst->Const(
214                     name, ctx.dst->ty.u32(),
215                     ctx.dst->Div(
216                         ctx.dst->Sub(buffer_size_result->variable->symbol,
217                                      array_offset),
218                         array_stride)));
219 
220                 // Insert the array length calculations at the top of the block
221                 ctx.InsertBefore(block->statements, block->statements[0],
222                                  buffer_size_result);
223                 ctx.InsertBefore(block->statements, block->statements[0],
224                                  call_get_dims);
225                 ctx.InsertBefore(block->statements, block->statements[0],
226                                  array_length_var);
227                 return name;
228               });
229 
230           // Replace the call to arrayLength() with the array length variable
231           ctx.Replace(call_expr, ctx.dst->Expr(array_length));
232         }
233       }
234     }
235   }
236 
237   ctx.Clone();
238 }
239 
240 }  // namespace transform
241 }  // namespace tint
242