// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/decompose_memory_access.h"

#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/ast/assignment_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/disable_validation_decoration.h"
#include "src/ast/type_name.h"
#include "src/ast/unary_op.h"
#include "src/block_allocator.h"
#include "src/program_builder.h"
#include "src/sem/array.h"
#include "src/sem/atomic_type.h"
#include "src/sem/call.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/reference_type.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
#include "src/utils/hash.h"
#include "src/utils/map.h"

TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess);
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic);

namespace tint {
namespace transform {

namespace {

/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
  /// @returns the offset ast::Expression, built in `ctx.dst`
  virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
};

/// OffsetExpr is an implementation of Offset that clones and casts the given
/// expression to `u32`.
struct OffsetExpr : Offset {
  const ast::Expression* const expr = nullptr;

  explicit OffsetExpr(const ast::Expression* e) : expr(e) {}

  const ast::Expression* Build(CloneContext& ctx) const override {
    auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
    auto* res = ctx.Clone(expr);
    if (!type->Is<sem::U32>()) {
      res = ctx.dst->Construct<ProgramBuilder::u32>(res);
    }
    return res;
  }
};

/// OffsetLiteral is an implementation of Offset that constructs a u32 literal
/// value.
struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
  uint32_t const literal = 0;

  explicit OffsetLiteral(uint32_t lit) : literal(lit) {}

  const ast::Expression* Build(CloneContext& ctx) const override {
    return ctx.dst->Expr(literal);
  }
};

/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
/// two Offsets.
struct OffsetBinOp : Offset {
  ast::BinaryOp op;
  Offset const* lhs = nullptr;
  Offset const* rhs = nullptr;

  const ast::Expression* Build(CloneContext& ctx) const override {
    return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx),
                                                  rhs->Build(ctx));
  }
};
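
// A minimal sketch of how these builders compose (illustrative only): an
// access such as `buf.arr[i]`, with a hypothetical member offset of 16 and
// element stride of 4, becomes
//   OffsetBinOp{kAdd, OffsetLiteral{16},
//               OffsetBinOp{kMultiply, OffsetLiteral{4}, OffsetExpr{i}}}
// whose Build() emits the AST expression `16u + (4u * u32(i))`.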

/// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey {
  ast::StorageClass const storage_class;  // buffer storage class
  sem::Type const* buf_ty = nullptr;      // buffer type
  sem::Type const* el_ty = nullptr;       // element type
  bool operator==(const LoadStoreKey& rhs) const {
    return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
           el_ty == rhs.el_ty;
  }
  struct Hasher {
    inline std::size_t operator()(const LoadStoreKey& u) const {
      return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
    }
  };
};

/// AtomicKey is the unordered map key to an atomic intrinsic.
struct AtomicKey {
  sem::Type const* buf_ty = nullptr;  // buffer type
  sem::Type const* el_ty = nullptr;   // element type
  sem::IntrinsicType const op;        // atomic op
  bool operator==(const AtomicKey& rhs) const {
    return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
  }
  struct Hasher {
    inline std::size_t operator()(const AtomicKey& u) const {
      return utils::Hash(u.buf_ty, u.el_ty, u.op);
    }
  };
};

bool IntrinsicDataTypeFor(const sem::Type* ty,
                          DecomposeMemoryAccess::Intrinsic::DataType& out) {
  if (ty->Is<sem::I32>()) {
    out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
    return true;
  }
  if (ty->Is<sem::U32>()) {
    out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
    return true;
  }
  if (ty->Is<sem::F32>()) {
    out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
    return true;
  }
  if (auto* vec = ty->As<sem::Vector>()) {
    switch (vec->Width()) {
      case 2:
        if (vec->type()->Is<sem::I32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
          return true;
        }
        if (vec->type()->Is<sem::U32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
          return true;
        }
        if (vec->type()->Is<sem::F32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
          return true;
        }
        break;
      case 3:
        if (vec->type()->Is<sem::I32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
          return true;
        }
        if (vec->type()->Is<sem::U32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
          return true;
        }
        if (vec->type()->Is<sem::F32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
          return true;
        }
        break;
      case 4:
        if (vec->type()->Is<sem::I32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
          return true;
        }
        if (vec->type()->Is<sem::U32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
          return true;
        }
        if (vec->type()->Is<sem::F32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
          return true;
        }
        break;
    }
    return false;
  }

  return false;
}
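
// Example (illustrative): for a `vec3<f32>` element type this sets `out` to
// DataType::kVec3F32 and returns true. Matrices, arrays and structs have no
// intrinsic data type, so it returns false and callers fall back to
// generating per-column / per-element / per-member helper functions.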

/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to load the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(
    ProgramBuilder* builder,
    ast::StorageClass storage_class,
    const sem::Type* ty) {
  DecomposeMemoryAccess::Intrinsic::DataType type;
  if (!IntrinsicDataTypeFor(ty, type)) {
    return nullptr;
  }
  return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class,
      type);
}

/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to store the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(
    ProgramBuilder* builder,
    ast::StorageClass storage_class,
    const sem::Type* ty) {
  DecomposeMemoryAccess::Intrinsic::DataType type;
  if (!IntrinsicDataTypeFor(ty, type)) {
    return nullptr;
  }
  return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
      storage_class, type);
}

/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function for the atomic op and the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
                                                     sem::IntrinsicType ity,
                                                     const sem::Type* ty) {
  auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
  switch (ity) {
    case sem::IntrinsicType::kAtomicLoad:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
      break;
    case sem::IntrinsicType::kAtomicStore:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
      break;
    case sem::IntrinsicType::kAtomicAdd:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
      break;
    case sem::IntrinsicType::kAtomicSub:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicSub;
      break;
    case sem::IntrinsicType::kAtomicMax:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
      break;
    case sem::IntrinsicType::kAtomicMin:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
      break;
    case sem::IntrinsicType::kAtomicAnd:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
      break;
    case sem::IntrinsicType::kAtomicOr:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
      break;
    case sem::IntrinsicType::kAtomicXor:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
      break;
    case sem::IntrinsicType::kAtomicExchange:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
      break;
    case sem::IntrinsicType::kAtomicCompareExchangeWeak:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
      break;
    default:
      TINT_ICE(Transform, builder->Diagnostics())
          << "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
          << ty->type_name();
      break;
  }

  DecomposeMemoryAccess::Intrinsic::DataType type;
  if (!IntrinsicDataTypeFor(ty, type)) {
    return nullptr;
  }
  return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      builder->ID(), op, ast::StorageClass::kStorage, type);
}

/// BufferAccess describes a single storage or uniform buffer access
struct BufferAccess {
  sem::Expression const* var = nullptr;  // Storage buffer variable
  Offset const* offset = nullptr;        // The byte offset on var
  sem::Type const* type = nullptr;       // The type of the access
  operator bool() const { return var; }  // Returns true if valid
};

/// Store describes a single storage or uniform buffer write
struct Store {
  const ast::AssignmentStatement* assignment;  // The AST assignment statement
  BufferAccess target;                         // The target for the write
};

}  // namespace

/// State holds the current transform state
struct DecomposeMemoryAccess::State {
  /// The clone context
  CloneContext& ctx;
  /// Alias to `*ctx.dst`
  ProgramBuilder& b;
  /// Map of AST expression to storage or uniform buffer access.
  /// Entries are added as accesses are encountered, and removed when outer
  /// expressions chain the access.
  /// Subset of #expression_order, as expressions are not removed from
  /// #expression_order.
  std::unordered_map<const ast::Expression*, BufferAccess> accesses;
  /// The visited order of AST expressions (superset of #accesses)
  std::vector<const ast::Expression*> expression_order;
  /// [buffer-type, element-type] -> load function name
  std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
  /// [buffer-type, element-type] -> store function name
  std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
  /// [buffer-type, element-type, atomic-op] -> atomic function name
  std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
  /// List of storage or uniform buffer writes
  std::vector<Store> stores;
  /// Allocations for offsets
  BlockAllocator<Offset> offsets_;

  /// Constructor
  /// @param context the CloneContext
  explicit State(CloneContext& context) : ctx(context), b(*ctx.dst) {}

  /// @param offset the offset value to wrap in an Offset
  /// @returns an Offset for the given literal value
  const Offset* ToOffset(uint32_t offset) {
    return offsets_.Create<OffsetLiteral>(offset);
  }

  /// @param expr the expression to convert to an Offset
  /// @returns an Offset for the given ast::Expression
  const Offset* ToOffset(const ast::Expression* expr) {
    if (auto* u32 = expr->As<ast::UintLiteralExpression>()) {
      return offsets_.Create<OffsetLiteral>(u32->value);
    } else if (auto* i32 = expr->As<ast::SintLiteralExpression>()) {
      if (i32->value > 0) {
        return offsets_.Create<OffsetLiteral>(i32->value);
      }
    }
    return offsets_.Create<OffsetExpr>(expr);
  }

  /// @param offset the Offset that is returned
  /// @returns the given offset (pass-through)
  const Offset* ToOffset(const Offset* offset) { return offset; }

  /// @param lhs_ the left-hand side of the add expression
  /// @param rhs_ the right-hand side of the add expression
  /// @return an Offset that is the sum of lhs and rhs, performing basic
  /// constant folding if possible
  template <typename LHS, typename RHS>
  const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
    auto* lhs = ToOffset(std::forward<LHS>(lhs_));
    auto* rhs = ToOffset(std::forward<RHS>(rhs_));
    auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
    auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
    if (lhs_lit && lhs_lit->literal == 0) {
      return rhs;
    }
    if (rhs_lit && rhs_lit->literal == 0) {
      return lhs;
    }
    if (lhs_lit && rhs_lit) {
      if (static_cast<uint64_t>(lhs_lit->literal) +
              static_cast<uint64_t>(rhs_lit->literal) <=
          0xffffffff) {
        return offsets_.Create<OffsetLiteral>(lhs_lit->literal +
                                              rhs_lit->literal);
      }
    }
    auto* out = offsets_.Create<OffsetBinOp>();
    out->op = ast::BinaryOp::kAdd;
    out->lhs = lhs;
    out->rhs = rhs;
    return out;
  }
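
  // Example (illustrative): Add(16u, 8u) folds to OffsetLiteral{24},
  // Add(expr, 0u) returns `expr` unchanged, and a sum that would overflow
  // u32 falls back to an OffsetBinOp evaluated in the shader.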

  /// @param lhs_ the left-hand side of the multiply expression
  /// @param rhs_ the right-hand side of the multiply expression
  /// @return an Offset that is the product of lhs and rhs, performing basic
  /// constant folding if possible
  template <typename LHS, typename RHS>
  const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
    auto* lhs = ToOffset(std::forward<LHS>(lhs_));
    auto* rhs = ToOffset(std::forward<RHS>(rhs_));
    auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
    auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
    if (lhs_lit && lhs_lit->literal == 0) {
      return offsets_.Create<OffsetLiteral>(0);
    }
    if (rhs_lit && rhs_lit->literal == 0) {
      return offsets_.Create<OffsetLiteral>(0);
    }
    if (lhs_lit && lhs_lit->literal == 1) {
      return rhs;
    }
    if (rhs_lit && rhs_lit->literal == 1) {
      return lhs;
    }
    if (lhs_lit && rhs_lit) {
      return offsets_.Create<OffsetLiteral>(lhs_lit->literal *
                                            rhs_lit->literal);
    }
    auto* out = offsets_.Create<OffsetBinOp>();
    out->op = ast::BinaryOp::kMultiply;
    out->lhs = lhs;
    out->rhs = rhs;
    return out;
  }
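
  // Example (illustrative): Mul(4u, index_expr) with a non-literal index
  // yields an OffsetBinOp multiply, Mul(1u, rhs) returns `rhs` unchanged,
  // and Mul(4u, 3u) folds directly to OffsetLiteral{12}.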

  /// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
  /// to #expression_order.
  /// @param expr the expression that performs the access
  /// @param access the access
  void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
    TINT_ASSERT(Transform, access.type);
    accesses.emplace(expr, access);
    expression_order.emplace_back(expr);
  }

  /// TakeAccess() removes the `node` item from #accesses (if it exists),
  /// returning the BufferAccess. If #accesses does not hold an item for
  /// `node`, an invalid BufferAccess is returned.
  /// @param node the expression that performed an access
  /// @return the BufferAccess for the given expression
  BufferAccess TakeAccess(const ast::Expression* node) {
    auto lhs_it = accesses.find(node);
    if (lhs_it == accesses.end()) {
      return {};
    }
    auto access = lhs_it->second;
    accesses.erase(node);
    return access;
  }

  /// LoadFunc() returns a symbol to an intrinsic function that loads an
  /// element of type `el_ty` from a storage or uniform buffer of type
  /// `buf_ty`. The emitted function has the signature:
  ///   `fn load(buf : buf_ty, offset : u32) -> el_ty`
  /// @param buf_ty the storage or uniform buffer type
  /// @param el_ty the storage or uniform buffer element type
  /// @param var_user the variable user
  /// @return the name of the function that performs the load
  Symbol LoadFunc(const sem::Type* buf_ty,
                  const sem::Type* el_ty,
                  const sem::VariableUser* var_user) {
    auto storage_class = var_user->Variable()->StorageClass();
    return utils::GetOrCreate(
        load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
          auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
          auto* disable_validation = b.Disable(
              ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);

          ast::VariableList params = {
              // Note: The buffer parameter requires the StorageClass in
              // order for HLSL to emit this as a ByteAddressBuffer or cbuffer
              // array.
              b.create<ast::Variable>(b.Sym("buffer"), storage_class,
                                      var_user->Variable()->Access(),
                                      buf_ast_ty, true, nullptr,
                                      ast::DecorationList{disable_validation}),
              b.Param("offset", b.ty.u32()),
          };

          auto name = b.Sym();

          if (auto* intrinsic =
                  IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
            auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
            auto* func = b.create<ast::Function>(
                name, params, el_ast_ty, nullptr,
                ast::DecorationList{
                    intrinsic,
                    b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
                },
                ast::DecorationList{});
            b.AST().AddFunction(func);
          } else if (auto* arr_ty = el_ty->As<sem::Array>()) {
            // fn load_func(buf : buf_ty, offset : u32) -> array<T, N> {
            //   var arr : array<T, N>;
            //   for (var i = 0u; i < array_count; i = i + 1) {
            //     arr[i] = el_load_func(buf, offset + i * array_stride)
            //   }
            //   return arr;
            // }
            auto load =
                LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
            auto* arr =
                b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
            auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
            auto* for_init = b.Decl(i);
            auto* for_cond = b.create<ast::BinaryExpression>(
                ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
            auto* for_cont = b.Assign(i, b.Add(i, 1u));
            auto* arr_el = b.IndexAccessor(arr, i);
            auto* el_offset =
                b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
            auto* el_val = b.Call(load, "buffer", el_offset);
            auto* for_loop = b.For(for_init, for_cond, for_cont,
                                   b.Block(b.Assign(arr_el, el_val)));

            b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
                   {
                       b.Decl(arr),
                       for_loop,
                       b.Return(arr),
                   });
          } else {
            ast::ExpressionList values;
            if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
              auto* vec_ty = mat_ty->ColumnType();
              Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
              for (uint32_t i = 0; i < mat_ty->columns(); i++) {
                auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
                values.emplace_back(b.Call(load, "buffer", offset));
              }
            } else if (auto* str = el_ty->As<sem::Struct>()) {
              for (auto* member : str->Members()) {
                auto* offset = b.Add("offset", member->Offset());
                Symbol load =
                    LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
                values.emplace_back(b.Call(load, "buffer", offset));
              }
            }
            b.Func(
                name, params, CreateASTTypeFor(ctx, el_ty),
                {
                    b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
                });
          }
          return name;
        });
  }
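
  // Example (illustrative): loading a `vec4<f32>` element emits only a
  // body-less stub carrying the Intrinsic decoration, which the HLSL writer
  // can later map to a ByteAddressBuffer load, along the lines of:
  //   [[internal(intrinsic_load_storage_vec4_f32)]]
  //   fn tint_symbol(buffer : SB, offset : u32) -> vec4<f32>
  // Composite element types (arrays, matrices, structs) instead get a real
  // body that loops or recurses down to these scalar/vector loads.
  // `tint_symbol` and `SB` are placeholder names.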

  /// StoreFunc() returns a symbol to an intrinsic function that stores an
  /// element of type `el_ty` to a storage buffer of type `buf_ty`.
  /// The function has the signature:
  ///   `fn store(buf : buf_ty, offset : u32, value : el_ty)`
  /// @param buf_ty the storage buffer type
  /// @param el_ty the storage buffer element type
  /// @param var_user the variable user
  /// @return the name of the function that performs the store
  Symbol StoreFunc(const sem::Type* buf_ty,
                   const sem::Type* el_ty,
                   const sem::VariableUser* var_user) {
    auto storage_class = var_user->Variable()->StorageClass();
    return utils::GetOrCreate(
        store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
          auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
          auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
          auto* disable_validation = b.Disable(
              ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
          ast::VariableList params{
              // Note: The buffer parameter requires the StorageClass in
              // order for HLSL to emit this as a ByteAddressBuffer.
              b.create<ast::Variable>(b.Sym("buffer"), storage_class,
                                      var_user->Variable()->Access(),
                                      buf_ast_ty, true, nullptr,
                                      ast::DecorationList{disable_validation}),
              b.Param("offset", b.ty.u32()),
              b.Param("value", el_ast_ty),
          };

          auto name = b.Sym();

          if (auto* intrinsic =
                  IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
            auto* func = b.create<ast::Function>(
                name, params, b.ty.void_(), nullptr,
                ast::DecorationList{
                    intrinsic,
                    b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
                },
                ast::DecorationList{});
            b.AST().AddFunction(func);
          } else {
            ast::StatementList body;
            if (auto* arr_ty = el_ty->As<sem::Array>()) {
              // fn store_func(buf : buf_ty, offset : u32, value : el_ty) {
              //   var array = value; // No dynamic indexing on constant arrays
              //   for (var i = 0u; i < array_count; i = i + 1) {
              //     el_store_func(buf, offset + i * array_stride, array[i])
              //   }
              // }
              auto* array =
                  b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
              auto store =
                  StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
              auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
              auto* for_init = b.Decl(i);
              auto* for_cond = b.create<ast::BinaryExpression>(
                  ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
              auto* for_cont = b.Assign(i, b.Add(i, 1u));
              auto* arr_el = b.IndexAccessor(array, i);
              auto* el_offset =
                  b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
              auto* store_stmt =
                  b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
              auto* for_loop =
                  b.For(for_init, for_cond, for_cont, b.Block(store_stmt));

              body = {b.Decl(array), for_loop};
            } else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
              auto* vec_ty = mat_ty->ColumnType();
              Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
              for (uint32_t i = 0; i < mat_ty->columns(); i++) {
                auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
                auto* access = b.IndexAccessor("value", i);
                auto* call = b.Call(store, "buffer", offset, access);
                body.emplace_back(b.CallStmt(call));
              }
            } else if (auto* str = el_ty->As<sem::Struct>()) {
              for (auto* member : str->Members()) {
                auto* offset = b.Add("offset", member->Offset());
                auto* access = b.MemberAccessor(
                    "value", ctx.Clone(member->Declaration()->symbol));
                Symbol store =
                    StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
                auto* call = b.Call(store, "buffer", offset, access);
                body.emplace_back(b.CallStmt(call));
              }
            }
            b.Func(name, params, b.ty.void_(), body);
          }

          return name;
        });
  }
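
  // Example (illustrative): storing a `mat2x2<f32>` element has no direct
  // intrinsic DataType, so the generated helper decomposes into per-column
  // vec2<f32> stores, roughly:
  //   fn tint_symbol(buffer : SB, offset : u32, value : mat2x2<f32>) {
  //     tint_symbol_1(buffer, offset + 0u, value[0u]);
  //     tint_symbol_1(buffer, offset + 8u, value[1u]);
  //   }
  // where 8 is the column stride and `tint_symbol`, `tint_symbol_1` and `SB`
  // are placeholder names.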

  /// AtomicFunc() returns a symbol to an intrinsic function that performs an
  /// atomic operation on a storage buffer of type `buf_ty`. The function has
  /// the signature:
  ///   `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
  /// @param buf_ty the storage buffer type
  /// @param el_ty the storage buffer element type
  /// @param intrinsic the atomic intrinsic
  /// @param var_user the variable user
  /// @return the name of the function that performs the atomic operation
  Symbol AtomicFunc(const sem::Type* buf_ty,
                    const sem::Type* el_ty,
                    const sem::Intrinsic* intrinsic,
                    const sem::VariableUser* var_user) {
    auto op = intrinsic->Type();
    return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] {
      auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
      auto* disable_validation = b.Disable(
          ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
      // The first parameter to all WGSL atomics is the expression to the
      // atomic. This is replaced with two parameters: the buffer and offset.
      ast::VariableList params = {
          // Note: The buffer parameter requires the kStorage StorageClass in
          // order for HLSL to emit this as a ByteAddressBuffer.
          b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage,
                                  var_user->Variable()->Access(), buf_ast_ty,
                                  true, nullptr,
                                  ast::DecorationList{disable_validation}),
          b.Param("offset", b.ty.u32()),
      };

      // Other parameters are copied as-is:
      for (size_t i = 1; i < intrinsic->Parameters().size(); i++) {
        auto* param = intrinsic->Parameters()[i];
        auto* ty = CreateASTTypeFor(ctx, param->Type());
        params.emplace_back(b.Param("param_" + std::to_string(i), ty));
      }

      auto* atomic = IntrinsicAtomicFor(ctx.dst, op, el_ty);
      if (atomic == nullptr) {
        TINT_ICE(Transform, b.Diagnostics())
            << "IntrinsicAtomicFor() returned nullptr for op " << op
            << " and type " << el_ty->type_name();
      }

      auto* ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
      auto* func = b.create<ast::Function>(
          b.Sym(), params, ret_ty, nullptr,
          ast::DecorationList{
              atomic,
              b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
          },
          ast::DecorationList{});

      b.AST().AddFunction(func);
      return func->symbol;
    });
  }
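
  // Example (illustrative): `atomicAdd(&sb.counter, 1u)` becomes a call to a
  // body-less stub along the lines of:
  //   [[internal(intrinsic_atomic_add_storage_u32)]]
  //   fn tint_symbol(buffer : SB, offset : u32, param_1 : u32) -> u32
  // which a backend such as HLSL can lower to InterlockedAdd. `tint_symbol`
  // and `SB` are placeholder names.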
};

DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
                                            Op o,
                                            ast::StorageClass sc,
                                            DataType ty)
    : Base(pid), op(o), storage_class(sc), type(ty) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
  std::stringstream ss;
  switch (op) {
    case Op::kLoad:
      ss << "intrinsic_load_";
      break;
    case Op::kStore:
      ss << "intrinsic_store_";
      break;
    case Op::kAtomicLoad:
      ss << "intrinsic_atomic_load_";
      break;
    case Op::kAtomicStore:
      ss << "intrinsic_atomic_store_";
      break;
    case Op::kAtomicAdd:
      ss << "intrinsic_atomic_add_";
      break;
    case Op::kAtomicSub:
      ss << "intrinsic_atomic_sub_";
      break;
    case Op::kAtomicMax:
      ss << "intrinsic_atomic_max_";
      break;
    case Op::kAtomicMin:
      ss << "intrinsic_atomic_min_";
      break;
    case Op::kAtomicAnd:
      ss << "intrinsic_atomic_and_";
      break;
    case Op::kAtomicOr:
      ss << "intrinsic_atomic_or_";
      break;
    case Op::kAtomicXor:
      ss << "intrinsic_atomic_xor_";
      break;
    case Op::kAtomicExchange:
      ss << "intrinsic_atomic_exchange_";
      break;
    case Op::kAtomicCompareExchangeWeak:
      ss << "intrinsic_atomic_compare_exchange_weak_";
      break;
  }
  ss << storage_class << "_";
  switch (type) {
    case DataType::kU32:
      ss << "u32";
      break;
    case DataType::kF32:
      ss << "f32";
      break;
    case DataType::kI32:
      ss << "i32";
      break;
    case DataType::kVec2U32:
      ss << "vec2_u32";
      break;
    case DataType::kVec2F32:
      ss << "vec2_f32";
      break;
    case DataType::kVec2I32:
      ss << "vec2_i32";
      break;
    case DataType::kVec3U32:
      ss << "vec3_u32";
      break;
    case DataType::kVec3F32:
      ss << "vec3_f32";
      break;
    case DataType::kVec3I32:
      ss << "vec3_i32";
      break;
    case DataType::kVec4U32:
      ss << "vec4_u32";
      break;
    case DataType::kVec4F32:
      ss << "vec4_f32";
      break;
    case DataType::kVec4I32:
      ss << "vec4_i32";
      break;
  }
  return ss.str();
}
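
// Example (illustrative, assuming ast::StorageClass::kStorage streams as
// "storage"): an Intrinsic with Op::kAtomicCompareExchangeWeak and
// DataType::kU32 is named "intrinsic_atomic_compare_exchange_weak_storage_u32".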

const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
    CloneContext* ctx) const {
  return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      ctx->dst->ID(), op, storage_class, type);
}

DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;

void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
  auto& sem = ctx.src->Sem();

  State state(ctx);

  // Scan the AST nodes for storage and uniform buffer accesses. Complex
  // expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
  // maintaining an offset chain via the `state.TakeAccess()` and
  // `state.AddAccess()` methods.
  //
  // Inner-most expression nodes are guaranteed to be visited first because AST
  // nodes are fully immutable and require their children to be constructed
  // first so their pointer can be passed to the parent's constructor.
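  //
  // Example (illustrative): for `storage_buffer.foo.bar[20].x`, the
  // identifier is visited first and seeds an access at offset 0; each
  // enclosing member / index accessor then takes that access and re-adds it
  // with the member offset, array stride or vector element size folded in,
  // leaving a single access on the outermost expression.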
  for (auto* node : ctx.src->ASTNodes().Objects()) {
    if (auto* ident = node->As<ast::IdentifierExpression>()) {
      // X
      if (auto* var = sem.Get<sem::VariableUser>(ident)) {
        if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
            var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
          // Variable to a storage or uniform buffer
          state.AddAccess(ident, {
                                     var,
                                     state.ToOffset(0u),
                                     var->Type()->UnwrapRef(),
                                 });
        }
      }
      continue;
    }

    if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
      // X.Y
      auto* accessor_sem = sem.Get(accessor);
      if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
        if (swizzle->Indices().size() == 1) {
          if (auto access = state.TakeAccess(accessor->structure)) {
            auto* vec_ty = access.type->As<sem::Vector>();
            auto* offset =
                state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0]);
            state.AddAccess(accessor, {
                                          access.var,
                                          state.Add(access.offset, offset),
                                          vec_ty->type()->UnwrapRef(),
                                      });
          }
        }
      } else {
        if (auto access = state.TakeAccess(accessor->structure)) {
          auto* str_ty = access.type->As<sem::Struct>();
          auto* member = str_ty->FindMember(accessor->member->symbol);
          auto offset = member->Offset();
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        member->Type()->UnwrapRef(),
                                    });
        }
      }
      continue;
    }

    if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
      if (auto access = state.TakeAccess(accessor->object)) {
        // X[Y]
        if (auto* arr = access.type->As<sem::Array>()) {
          auto* offset = state.Mul(arr->Stride(), accessor->index);
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        arr->ElemType()->UnwrapRef(),
                                    });
          continue;
        }
        if (auto* vec_ty = access.type->As<sem::Vector>()) {
          auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        vec_ty->type()->UnwrapRef(),
                                    });
          continue;
        }
        if (auto* mat_ty = access.type->As<sem::Matrix>()) {
          auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        mat_ty->ColumnType(),
                                    });
          continue;
        }
      }
    }

    if (auto* op = node->As<ast::UnaryOpExpression>()) {
      if (op->op == ast::UnaryOp::kAddressOf) {
        // &X
        if (auto access = state.TakeAccess(op->expr)) {
          // HLSL does not support pointers, so just take the access from the
          // reference and place it on the pointer.
          state.AddAccess(op, access);
          continue;
        }
      }
    }

    if (auto* assign = node->As<ast::AssignmentStatement>()) {
      // X = Y
      // Move the LHS access to a store.
      if (auto lhs = state.TakeAccess(assign->lhs)) {
        state.stores.emplace_back(Store{assign, lhs});
      }
    }

    if (auto* call_expr = node->As<ast::CallExpression>()) {
      auto* call = sem.Get(call_expr);
      if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) {
        if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {  // [DEPRECATED]
          // ignore(X)
          // If X is a memory access, don't transform it into a load, as it
          // may refer to a structure holding a runtime array, which cannot be
          // loaded. Instead replace X with the underlying storage / uniform
          // buffer variable.
          if (auto access = state.TakeAccess(call_expr->args[0])) {
            ctx.Replace(call_expr->args[0], [=, &ctx] {
              return ctx.CloneWithoutTransform(access.var->Declaration());
            });
          }
          continue;
        }
        if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
          // arrayLength(X)
          // Don't convert X into a load, as this intrinsic requires the
          // real pointer.
          state.TakeAccess(call_expr->args[0]);
          continue;
        }
        if (intrinsic->IsAtomic()) {
          if (auto access = state.TakeAccess(call_expr->args[0])) {
            // atomic___(X)
            ctx.Replace(call_expr, [=, &ctx, &state] {
              auto* buf = access.var->Declaration();
              auto* offset = access.offset->Build(ctx);
              auto* buf_ty = access.var->Type()->UnwrapRef();
              auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
              Symbol func =
                  state.AtomicFunc(buf_ty, el_ty, intrinsic,
                                   access.var->As<sem::VariableUser>());

              ast::ExpressionList args{ctx.Clone(buf), offset};
              for (size_t i = 1; i < call_expr->args.size(); i++) {
                auto* arg = call_expr->args[i];
                args.emplace_back(ctx.Clone(arg));
              }
              return ctx.dst->Call(func, args);
            });
          }
        }
      }
    }
  }

  // All remaining accesses are loads, so transform these into calls to the
  // corresponding load function.
  for (auto* expr : state.expression_order) {
    auto access_it = state.accesses.find(expr);
    if (access_it == state.accesses.end()) {
      continue;
    }
    BufferAccess access = access_it->second;
    ctx.Replace(expr, [=, &ctx, &state] {
      auto* buf = access.var->Declaration();
      auto* offset = access.offset->Build(ctx);
      auto* buf_ty = access.var->Type()->UnwrapRef();
      auto* el_ty = access.type->UnwrapRef();
      Symbol func =
          state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
      return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
    });
  }

  // And replace all storage and uniform buffer assignments with stores.
  for (auto store : state.stores) {
    ctx.Replace(store.assignment, [=, &ctx, &state] {
      auto* buf = store.target.var->Declaration();
      auto* offset = store.target.offset->Build(ctx);
      auto* buf_ty = store.target.var->Type()->UnwrapRef();
      auto* el_ty = store.target.type->UnwrapRef();
      auto* value = store.assignment->rhs;
      Symbol func = state.StoreFunc(buf_ty, el_ty,
                                    store.target.var->As<sem::VariableUser>());
      auto* call = ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset,
                                 ctx.Clone(value));
      return ctx.dst->CallStmt(call);
    });
  }

  ctx.Clone();
}

}  // namespace transform
}  // namespace tint

TINT_INSTANTIATE_TYPEINFO(tint::transform::Offset);
TINT_INSTANTIATE_TYPEINFO(tint::transform::OffsetLiteral);