// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/decompose_memory_access.h"

#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/ast/assignment_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/disable_validation_decoration.h"
#include "src/ast/type_name.h"
#include "src/ast/unary_op.h"
#include "src/block_allocator.h"
#include "src/program_builder.h"
#include "src/sem/array.h"
#include "src/sem/atomic_type.h"
#include "src/sem/call.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/reference_type.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
#include "src/utils/hash.h"
#include "src/utils/map.h"

TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess);
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic);

namespace tint {
namespace transform {

namespace {

/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
  /// @returns builds and returns the ast::Expression in `ctx.dst`
  virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
};

/// OffsetExpr is an implementation of Offset that clones and casts the given
/// expression to `u32`.
struct OffsetExpr : Offset {
  const ast::Expression* const expr = nullptr;

  explicit OffsetExpr(const ast::Expression* e) : expr(e) {}

  const ast::Expression* Build(CloneContext& ctx) const override {
    auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
    auto* res = ctx.Clone(expr);
    if (!type->Is<sem::U32>()) {
      res = ctx.dst->Construct<ProgramBuilder::u32>(res);
    }
    return res;
  }
};

/// OffsetLiteral is an implementation of Offset that constructs a u32 literal
/// value.
struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
  uint32_t const literal = 0;

  explicit OffsetLiteral(uint32_t lit) : literal(lit) {}

  const ast::Expression* Build(CloneContext& ctx) const override {
    return ctx.dst->Expr(literal);
  }
};

/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
/// two Offsets.
struct OffsetBinOp : Offset {
  ast::BinaryOp op;
  Offset const* lhs = nullptr;
  Offset const* rhs = nullptr;

  const ast::Expression* Build(CloneContext& ctx) const override {
    return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx),
                                                  rhs->Build(ctx));
  }
};

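// As an illustration (hypothetical values, not produced by the code above):
// accessing an element at byte offset 8 followed by indexing into an array
// with a 16 byte stride composes as an Offset tree of roughly
//   Add(OffsetLiteral(8), Mul(OffsetLiteral(16), OffsetExpr(idx)))
// whose Build() emits something like the WGSL expression `8u + (16u * u32(idx))`.
// OffsetExpr only inserts the u32() cast when `idx` is not already u32, and
// literal-literal operands are folded by State::Add()/Mul() further below, so
// the tree is only turned into AST when a load/store/atomic call is emitted.
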
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey {
  ast::StorageClass const storage_class;  // buffer storage class
  sem::Type const* buf_ty = nullptr;      // buffer type
  sem::Type const* el_ty = nullptr;       // element type
  bool operator==(const LoadStoreKey& rhs) const {
    return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
           el_ty == rhs.el_ty;
  }
  struct Hasher {
    inline std::size_t operator()(const LoadStoreKey& u) const {
      return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
    }
  };
};

/// AtomicKey is the unordered map key to an atomic intrinsic.
struct AtomicKey {
  sem::Type const* buf_ty = nullptr;  // buffer type
  sem::Type const* el_ty = nullptr;   // element type
  sem::IntrinsicType const op;        // atomic op
  bool operator==(const AtomicKey& rhs) const {
    return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
  }
  struct Hasher {
    inline std::size_t operator()(const AtomicKey& u) const {
      return utils::Hash(u.buf_ty, u.el_ty, u.op);
    }
  };
};

bool IntrinsicDataTypeFor(const sem::Type* ty,
                          DecomposeMemoryAccess::Intrinsic::DataType& out) {
  if (ty->Is<sem::I32>()) {
    out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
    return true;
  }
  if (ty->Is<sem::U32>()) {
    out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
    return true;
  }
  if (ty->Is<sem::F32>()) {
    out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
    return true;
  }
  if (auto* vec = ty->As<sem::Vector>()) {
    switch (vec->Width()) {
      case 2:
        if (vec->type()->Is<sem::I32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
          return true;
        }
        if (vec->type()->Is<sem::U32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
          return true;
        }
        if (vec->type()->Is<sem::F32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
          return true;
        }
        break;
      case 3:
        if (vec->type()->Is<sem::I32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
          return true;
        }
        if (vec->type()->Is<sem::U32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
          return true;
        }
        if (vec->type()->Is<sem::F32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
          return true;
        }
        break;
      case 4:
        if (vec->type()->Is<sem::I32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
          return true;
        }
        if (vec->type()->Is<sem::U32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
          return true;
        }
        if (vec->type()->Is<sem::F32>()) {
          out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
          return true;
        }
        break;
    }
    return false;
  }

  return false;
}

/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to load the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(
    ProgramBuilder* builder,
    ast::StorageClass storage_class,
    const sem::Type* ty) {
  DecomposeMemoryAccess::Intrinsic::DataType type;
  if (!IntrinsicDataTypeFor(ty, type)) {
    return nullptr;
  }
  return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class,
      type);
}

/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to store the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(
    ProgramBuilder* builder,
    ast::StorageClass storage_class,
    const sem::Type* ty) {
  DecomposeMemoryAccess::Intrinsic::DataType type;
  if (!IntrinsicDataTypeFor(ty, type)) {
    return nullptr;
  }
  return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
      storage_class, type);
}

/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function for the atomic op and the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
                                                     sem::IntrinsicType ity,
                                                     const sem::Type* ty) {
  auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
  switch (ity) {
    case sem::IntrinsicType::kAtomicLoad:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
      break;
    case sem::IntrinsicType::kAtomicStore:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
      break;
    case sem::IntrinsicType::kAtomicAdd:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
      break;
    case sem::IntrinsicType::kAtomicSub:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicSub;
      break;
    case sem::IntrinsicType::kAtomicMax:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
      break;
    case sem::IntrinsicType::kAtomicMin:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
      break;
    case sem::IntrinsicType::kAtomicAnd:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
      break;
    case sem::IntrinsicType::kAtomicOr:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
      break;
    case sem::IntrinsicType::kAtomicXor:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
      break;
    case sem::IntrinsicType::kAtomicExchange:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
      break;
    case sem::IntrinsicType::kAtomicCompareExchangeWeak:
      op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
      break;
    default:
      TINT_ICE(Transform, builder->Diagnostics())
          << "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
          << ty->type_name();
      break;
  }

  DecomposeMemoryAccess::Intrinsic::DataType type;
  if (!IntrinsicDataTypeFor(ty, type)) {
    return nullptr;
  }
  return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      builder->ID(), op, ast::StorageClass::kStorage, type);
}

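// For orientation, a sketch of how these decorations end up being used by
// LoadFunc()/StoreFunc()/AtomicFunc() below (`SB` is an illustrative buffer
// struct name, not something this file defines): a leaf type such as i32 in
// the storage class produces a body-less stub function
//   fn tint_symbol(buffer : SB, offset : u32) -> i32
// decorated with an Intrinsic whose InternalName() is
// "intrinsic_load_storage_i32" plus a disable-validation decoration. Backends
// are then expected to replace calls to such stubs with real buffer accesses.
// Composite element types (arrays, structures, matrices) return nullptr from
// the helpers above and are instead decomposed into loops or per-member calls.
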
/// BufferAccess describes a single storage or uniform buffer access
struct BufferAccess {
  sem::Expression const* var = nullptr;  // Storage buffer variable
  Offset const* offset = nullptr;        // The byte offset on var
  sem::Type const* type = nullptr;       // The type of the access
  operator bool() const { return var; }  // Returns true if valid
};

/// Store describes a single storage or uniform buffer write
struct Store {
  ast::AssignmentStatement* assignment;  // The AST assignment statement
  BufferAccess target;                   // The target for the write
};

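// A worked (hypothetical) example of how these records are populated by Run()
// below: for the WGSL statement `sb.inner.arr[i] = 1.0;` where `sb` is a
// storage buffer and `arr` is an array<f32>, the resulting Store holds the
// original assignment statement and a BufferAccess with
//   var    = the sem::Expression for `sb`
//   offset = offset(inner) + offset(arr) + i * array_stride, as an Offset tree
//   type   = f32 (the unwrapped element type being written)
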
}  // namespace

/// State holds the current transform state
struct DecomposeMemoryAccess::State {
  /// The clone context
  CloneContext& ctx;
  /// Alias to `*ctx.dst`
  ProgramBuilder& b;
  /// Map of AST expression to storage or uniform buffer access
  /// This map has entries added when encountered, and removed when outer
  /// expressions chain the access.
  /// Subset of #expression_order, as expressions are not removed from
  /// #expression_order.
  std::unordered_map<const ast::Expression*, BufferAccess> accesses;
  /// The visited order of AST expressions (superset of #accesses)
  std::vector<const ast::Expression*> expression_order;
  /// [buffer-type, element-type] -> load function name
  std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
  /// [buffer-type, element-type] -> store function name
  std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
  /// [buffer-type, element-type, atomic-op] -> atomic function name
  std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
  /// List of storage or uniform buffer writes
  std::vector<Store> stores;
  /// Allocations for offsets
  BlockAllocator<Offset> offsets_;

  /// Constructor
  /// @param context the CloneContext
  explicit State(CloneContext& context) : ctx(context), b(*ctx.dst) {}

  /// @param offset the offset value to wrap in an Offset
  /// @returns an Offset for the given literal value
  const Offset* ToOffset(uint32_t offset) {
    return offsets_.Create<OffsetLiteral>(offset);
  }

  /// @param expr the expression to convert to an Offset
  /// @returns an Offset for the given ast::Expression
  const Offset* ToOffset(const ast::Expression* expr) {
    if (auto* u32 = expr->As<ast::UintLiteralExpression>()) {
      return offsets_.Create<OffsetLiteral>(u32->value);
    } else if (auto* i32 = expr->As<ast::SintLiteralExpression>()) {
      if (i32->value > 0) {
        return offsets_.Create<OffsetLiteral>(i32->value);
      }
    }
    return offsets_.Create<OffsetExpr>(expr);
  }

  /// @param offset the Offset that is returned
  /// @returns the given offset (pass-through)
  const Offset* ToOffset(const Offset* offset) { return offset; }

  /// @param lhs_ the left-hand side of the add expression
  /// @param rhs_ the right-hand side of the add expression
  /// @return an Offset that is a sum of lhs and rhs, performing basic constant
  /// folding if possible
  template <typename LHS, typename RHS>
  const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
    auto* lhs = ToOffset(std::forward<LHS>(lhs_));
    auto* rhs = ToOffset(std::forward<RHS>(rhs_));
    auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
    auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
    if (lhs_lit && lhs_lit->literal == 0) {
      return rhs;
    }
    if (rhs_lit && rhs_lit->literal == 0) {
      return lhs;
    }
    if (lhs_lit && rhs_lit) {
      if (static_cast<uint64_t>(lhs_lit->literal) +
              static_cast<uint64_t>(rhs_lit->literal) <=
          0xffffffff) {
        return offsets_.Create<OffsetLiteral>(lhs_lit->literal +
                                              rhs_lit->literal);
      }
    }
    auto* out = offsets_.Create<OffsetBinOp>();
    out->op = ast::BinaryOp::kAdd;
    out->lhs = lhs;
    out->rhs = rhs;
    return out;
  }

  /// @param lhs_ the left-hand side of the multiply expression
  /// @param rhs_ the right-hand side of the multiply expression
  /// @return an Offset that is the multiplication of lhs and rhs, performing
  /// basic constant folding if possible
  template <typename LHS, typename RHS>
  const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
    auto* lhs = ToOffset(std::forward<LHS>(lhs_));
    auto* rhs = ToOffset(std::forward<RHS>(rhs_));
    auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
    auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
    if (lhs_lit && lhs_lit->literal == 0) {
      return offsets_.Create<OffsetLiteral>(0);
    }
    if (rhs_lit && rhs_lit->literal == 0) {
      return offsets_.Create<OffsetLiteral>(0);
    }
    if (lhs_lit && lhs_lit->literal == 1) {
      return rhs;
    }
    if (rhs_lit && rhs_lit->literal == 1) {
      return lhs;
    }
    if (lhs_lit && rhs_lit) {
      return offsets_.Create<OffsetLiteral>(lhs_lit->literal *
                                            rhs_lit->literal);
    }
    auto* out = offsets_.Create<OffsetBinOp>();
    out->op = ast::BinaryOp::kMultiply;
    out->lhs = lhs;
    out->rhs = rhs;
    return out;
  }

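  // A small illustration (hypothetical operand values) of the folding done by
  // Add() and Mul() above:
  //   Mul(16u, 2u)          -> OffsetLiteral(32)
  //   Add(0u, some_offset)  -> some_offset (additive identity elided)
  //   Mul(1u, some_offset)  -> some_offset (multiplicative identity elided)
  //   Add(4u, some_offset)  -> OffsetBinOp{kAdd, OffsetLiteral(4), some_offset}
  // A new OffsetLiteral is only produced when both operands are literals (and,
  // for Add, when the sum fits in 32 bits); otherwise an OffsetBinOp is built.
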
  /// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
  /// to #expression_order.
  /// @param expr the expression that performs the access
  /// @param access the access
  void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
    TINT_ASSERT(Transform, access.type);
    accesses.emplace(expr, access);
    expression_order.emplace_back(expr);
  }

  /// TakeAccess() removes the `node` item from #accesses (if it exists),
  /// returning the BufferAccess. If #accesses does not hold an item for
  /// `node`, an invalid BufferAccess is returned.
  /// @param node the expression that performed an access
  /// @return the BufferAccess for the given expression
  BufferAccess TakeAccess(const ast::Expression* node) {
    auto lhs_it = accesses.find(node);
    if (lhs_it == accesses.end()) {
      return {};
    }
    auto access = lhs_it->second;
    accesses.erase(node);
    return access;
  }

  /// LoadFunc() returns a symbol to an intrinsic function that loads an element
  /// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
  /// The emitted function has the signature:
  ///   `fn load(buf : buf_ty, offset : u32) -> el_ty`
  /// @param buf_ty the storage or uniform buffer type
  /// @param el_ty the storage or uniform buffer element type
  /// @param var_user the variable user
  /// @return the name of the function that performs the load
  Symbol LoadFunc(const sem::Type* buf_ty,
                  const sem::Type* el_ty,
                  const sem::VariableUser* var_user) {
    auto storage_class = var_user->Variable()->StorageClass();
    return utils::GetOrCreate(
        load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
          auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
          auto* disable_validation = b.Disable(
              ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);

          ast::VariableList params = {
              // Note: The buffer parameter requires the StorageClass in
              // order for HLSL to emit this as a ByteAddressBuffer or cbuffer
              // array.
              b.create<ast::Variable>(b.Sym("buffer"), storage_class,
                                      var_user->Variable()->Access(),
                                      buf_ast_ty, true, nullptr,
                                      ast::DecorationList{disable_validation}),
              b.Param("offset", b.ty.u32()),
          };

          auto name = b.Sym();

          if (auto* intrinsic =
                  IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
            auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
            auto* func = b.create<ast::Function>(
                name, params, el_ast_ty, nullptr,
                ast::DecorationList{
                    intrinsic,
                    b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
                },
                ast::DecorationList{});
            b.AST().AddFunction(func);
          } else if (auto* arr_ty = el_ty->As<sem::Array>()) {
            // fn load_func(buf : buf_ty, offset : u32) -> array<T, N> {
            //   var arr : array<T, N>;
            //   for (var i = 0u; i < array_count; i = i + 1) {
            //     arr[i] = el_load_func(buf, offset + i * array_stride)
            //   }
            //   return arr;
            // }
            auto load =
                LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
            auto* arr =
                b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
            auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
            auto* for_init = b.Decl(i);
            auto* for_cond = b.create<ast::BinaryExpression>(
                ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
            auto* for_cont = b.Assign(i, b.Add(i, 1u));
            auto* arr_el = b.IndexAccessor(arr, i);
            auto* el_offset =
                b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
            auto* el_val = b.Call(load, "buffer", el_offset);
            auto* for_loop = b.For(for_init, for_cond, for_cont,
                                   b.Block(b.Assign(arr_el, el_val)));

            b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
                   {
                       b.Decl(arr),
                       for_loop,
                       b.Return(arr),
                   });
          } else {
            ast::ExpressionList values;
            if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
              auto* vec_ty = mat_ty->ColumnType();
              Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
              for (uint32_t i = 0; i < mat_ty->columns(); i++) {
                auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
                values.emplace_back(b.Call(load, "buffer", offset));
              }
            } else if (auto* str = el_ty->As<sem::Struct>()) {
              for (auto* member : str->Members()) {
                auto* offset = b.Add("offset", member->Offset());
                Symbol load =
                    LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
                values.emplace_back(b.Call(load, "buffer", offset));
              }
            }
            b.Func(
                name, params, CreateASTTypeFor(ctx, el_ty),
                {
                    b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
                });
          }
          return name;
        });
  }

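  // A sketch of the struct path above (names and offsets are illustrative; the
  // real symbols are generated and the offsets come from the resolved struct
  // layout): loading `struct S { a : f32; b : vec4<f32>; }` from a storage
  // buffer would emit roughly
  //
  //   fn load_S(buffer : SB, offset : u32) -> S {
  //     return S(load_f32(buffer, offset + 0u),
  //              load_vec4_f32(buffer, offset + 16u));
  //   }
  //
  // where load_f32 / load_vec4_f32 are the body-less stubs created via
  // IntrinsicLoadFor(), and 16 assumes vec4 alignment for member `b`.
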
  /// StoreFunc() returns a symbol to an intrinsic function that stores an
  /// element of type `el_ty` to a storage buffer of type `buf_ty`.
  /// The function has the signature:
  ///   `fn store(buf : buf_ty, offset : u32, value : el_ty)`
  /// @param buf_ty the storage buffer type
  /// @param el_ty the storage buffer element type
  /// @param var_user the variable user
  /// @return the name of the function that performs the store
  Symbol StoreFunc(const sem::Type* buf_ty,
                   const sem::Type* el_ty,
                   const sem::VariableUser* var_user) {
    auto storage_class = var_user->Variable()->StorageClass();
    return utils::GetOrCreate(
        store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
          auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
          auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
          auto* disable_validation = b.Disable(
              ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
          ast::VariableList params{
              // Note: The buffer parameter requires the StorageClass in
              // order for HLSL to emit this as a ByteAddressBuffer.

              b.create<ast::Variable>(b.Sym("buffer"), storage_class,
                                      var_user->Variable()->Access(),
                                      buf_ast_ty, true, nullptr,
                                      ast::DecorationList{disable_validation}),
              b.Param("offset", b.ty.u32()),
              b.Param("value", el_ast_ty),
          };

          auto name = b.Sym();

          if (auto* intrinsic =
                  IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
            auto* func = b.create<ast::Function>(
                name, params, b.ty.void_(), nullptr,
                ast::DecorationList{
                    intrinsic,
                    b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
                },
                ast::DecorationList{});
            b.AST().AddFunction(func);
          } else {
            ast::StatementList body;
            if (auto* arr_ty = el_ty->As<sem::Array>()) {
              // fn store_func(buf : buf_ty, offset : u32, value : el_ty) {
              //   var array = value; // No dynamic indexing on constant arrays
              //   for (var i = 0u; i < array_count; i = i + 1) {
              //     el_store_func(buf, offset + i * array_stride, array[i])
              //   }
              // }
              auto* array =
                  b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
              auto store =
                  StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
              auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
              auto* for_init = b.Decl(i);
              auto* for_cond = b.create<ast::BinaryExpression>(
                  ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
              auto* for_cont = b.Assign(i, b.Add(i, 1u));
              auto* arr_el = b.IndexAccessor(array, i);
              auto* el_offset =
                  b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
              auto* store_stmt =
                  b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
              auto* for_loop =
                  b.For(for_init, for_cond, for_cont, b.Block(store_stmt));

              body = {b.Decl(array), for_loop};
            } else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
              auto* vec_ty = mat_ty->ColumnType();
              Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
              for (uint32_t i = 0; i < mat_ty->columns(); i++) {
                auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
                auto* access = b.IndexAccessor("value", i);
                auto* call = b.Call(store, "buffer", offset, access);
                body.emplace_back(b.CallStmt(call));
              }
            } else if (auto* str = el_ty->As<sem::Struct>()) {
              for (auto* member : str->Members()) {
                auto* offset = b.Add("offset", member->Offset());
                auto* access = b.MemberAccessor(
                    "value", ctx.Clone(member->Declaration()->symbol));
                Symbol store =
                    StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
                auto* call = b.Call(store, "buffer", offset, access);
                body.emplace_back(b.CallStmt(call));
              }
            }
            b.Func(name, params, b.ty.void_(), body);
          }

          return name;
        });
  }

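  // A sketch of the matrix path above (illustrative names; the column stride
  // is an assumption based on standard vec4 layout): storing a mat2x4<f32>
  // element decomposes into per-column calls through the vec4<f32> store
  // function, roughly
  //
  //   fn store_mat2x4(buffer : SB, offset : u32, value : mat2x4<f32>) {
  //     store_vec4_f32(buffer, offset + 0u, value[0u]);
  //     store_vec4_f32(buffer, offset + 16u, value[1u]);
  //   }
  //
  // with the per-column byte offset taken from ColumnStride().
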
  /// AtomicFunc() returns a symbol to an intrinsic function that performs an
  /// atomic operation from a storage buffer of type `buf_ty`. The function has
  /// the signature:
  ///   `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
  /// @param buf_ty the storage buffer type
  /// @param el_ty the storage buffer element type
  /// @param intrinsic the atomic intrinsic
  /// @param var_user the variable user
  /// @return the name of the function that performs the atomic operation
  Symbol AtomicFunc(const sem::Type* buf_ty,
                    const sem::Type* el_ty,
                    const sem::Intrinsic* intrinsic,
                    const sem::VariableUser* var_user) {
    auto op = intrinsic->Type();
    return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] {
      auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
      auto* disable_validation = b.Disable(
          ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
      // The first parameter to all WGSL atomics is the expression to the
      // atomic. This is replaced with two parameters: the buffer and offset.

      ast::VariableList params = {
          // Note: The buffer parameter requires the kStorage StorageClass in
          // order for HLSL to emit this as a ByteAddressBuffer.
          b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage,
                                  var_user->Variable()->Access(), buf_ast_ty,
                                  true, nullptr,
                                  ast::DecorationList{disable_validation}),
          b.Param("offset", b.ty.u32()),
      };

      // Other parameters are copied as-is:
      for (size_t i = 1; i < intrinsic->Parameters().size(); i++) {
        auto* param = intrinsic->Parameters()[i];
        auto* ty = CreateASTTypeFor(ctx, param->Type());
        params.emplace_back(b.Param("param_" + std::to_string(i), ty));
      }

      auto* atomic = IntrinsicAtomicFor(ctx.dst, op, el_ty);
      if (atomic == nullptr) {
        TINT_ICE(Transform, b.Diagnostics())
            << "IntrinsicAtomicFor() returned nullptr for op " << op
            << " and type " << el_ty->type_name();
      }

      auto* ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
      auto* func = b.create<ast::Function>(
          b.Sym(), params, ret_ty, nullptr,
          ast::DecorationList{
              atomic,
              b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
          },
          ast::DecorationList{});

      b.AST().AddFunction(func);
      return func->symbol;
    });
  }
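
  // For example (a sketch, assuming atomicAdd on a storage buffer member of
  // type atomic<u32>; `SB` and the symbol names are illustrative): the stub
  // generated here has the shape
  //
  //   fn tint_symbol(buffer : SB, offset : u32, param_1 : u32) -> u32
  //
  // decorated with the "intrinsic_atomic_add_storage_u32" Intrinsic and no
  // body, and the call site `atomicAdd(&sb.counter, 1u)` is rewritten in Run()
  // below into `tint_symbol(sb, <byte offset of counter>, 1u)`.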
};

DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
                                            Op o,
                                            ast::StorageClass sc,
                                            DataType ty)
    : Base(pid), op(o), storage_class(sc), type(ty) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
  std::stringstream ss;
  switch (op) {
    case Op::kLoad:
      ss << "intrinsic_load_";
      break;
    case Op::kStore:
      ss << "intrinsic_store_";
      break;
    case Op::kAtomicLoad:
      ss << "intrinsic_atomic_load_";
      break;
    case Op::kAtomicStore:
      ss << "intrinsic_atomic_store_";
      break;
    case Op::kAtomicAdd:
      ss << "intrinsic_atomic_add_";
      break;
    case Op::kAtomicSub:
      ss << "intrinsic_atomic_sub_";
      break;
    case Op::kAtomicMax:
      ss << "intrinsic_atomic_max_";
      break;
    case Op::kAtomicMin:
      ss << "intrinsic_atomic_min_";
      break;
    case Op::kAtomicAnd:
      ss << "intrinsic_atomic_and_";
      break;
    case Op::kAtomicOr:
      ss << "intrinsic_atomic_or_";
      break;
    case Op::kAtomicXor:
      ss << "intrinsic_atomic_xor_";
      break;
    case Op::kAtomicExchange:
      ss << "intrinsic_atomic_exchange_";
      break;
    case Op::kAtomicCompareExchangeWeak:
      ss << "intrinsic_atomic_compare_exchange_weak_";
      break;
  }
  ss << storage_class << "_";
  switch (type) {
    case DataType::kU32:
      ss << "u32";
      break;
    case DataType::kF32:
      ss << "f32";
      break;
    case DataType::kI32:
      ss << "i32";
      break;
    case DataType::kVec2U32:
      ss << "vec2_u32";
      break;
    case DataType::kVec2F32:
      ss << "vec2_f32";
      break;
    case DataType::kVec2I32:
      ss << "vec2_i32";
      break;
    case DataType::kVec3U32:
      ss << "vec3_u32";
      break;
    case DataType::kVec3F32:
      ss << "vec3_f32";
      break;
    case DataType::kVec3I32:
      ss << "vec3_i32";
      break;
    case DataType::kVec4U32:
      ss << "vec4_u32";
      break;
    case DataType::kVec4F32:
      ss << "vec4_f32";
      break;
    case DataType::kVec4I32:
      ss << "vec4_i32";
      break;
  }
  return ss.str();
}

const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
    CloneContext* ctx) const {
  return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
      ctx->dst->ID(), op, storage_class, type);
}

DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;

void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
  auto& sem = ctx.src->Sem();

  State state(ctx);

  // Scan the AST nodes for storage and uniform buffer accesses. Complex
  // expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
  // maintaining an offset chain via the `state.TakeAccess()` and
  // `state.AddAccess()` methods.
  //
  // Inner-most expression nodes are guaranteed to be visited first because AST
  // nodes are fully immutable and require their children to be constructed
  // first so their pointer can be passed to the parent's constructor.
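  //
  // As an illustration, for the (hypothetical) expression `sb.foo.bar[i].x`
  // the walk below proceeds innermost-first, each step consuming its
  // sub-expression's entry via TakeAccess() and adding its own:
  //   sb              -> {sb, 0u,                               SB}
  //   sb.foo          -> {sb, offset(foo),                      Foo}
  //   sb.foo.bar      -> {sb, offset(foo) + offset(bar),        array}
  //   sb.foo.bar[i]   -> {sb, ... + i * array_stride,           vecN}
  //   sb.foo.bar[i].x -> {sb, ... (x adds a folded 0u),         f32}
  // so only the outermost expression still has an entry in #accesses when the
  // load/store rewrites run afterwards.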
  for (auto* node : ctx.src->ASTNodes().Objects()) {
    if (auto* ident = node->As<ast::IdentifierExpression>()) {
      // X
      if (auto* var = sem.Get<sem::VariableUser>(ident)) {
        if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
            var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
          // Variable to a storage or uniform buffer
          state.AddAccess(ident, {
                                     var,
                                     state.ToOffset(0u),
                                     var->Type()->UnwrapRef(),
                                 });
        }
      }
      continue;
    }

    if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
      // X.Y
      auto* accessor_sem = sem.Get(accessor);
      if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
        if (swizzle->Indices().size() == 1) {
          if (auto access = state.TakeAccess(accessor->structure)) {
            auto* vec_ty = access.type->As<sem::Vector>();
            auto* offset =
                state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0]);
            state.AddAccess(accessor, {
                                          access.var,
                                          state.Add(access.offset, offset),
                                          vec_ty->type()->UnwrapRef(),
                                      });
          }
        }
      } else {
        if (auto access = state.TakeAccess(accessor->structure)) {
          auto* str_ty = access.type->As<sem::Struct>();
          auto* member = str_ty->FindMember(accessor->member->symbol);
          auto offset = member->Offset();
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        member->Type()->UnwrapRef(),
                                    });
        }
      }
      continue;
    }

    if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
      if (auto access = state.TakeAccess(accessor->object)) {
        // X[Y]
        if (auto* arr = access.type->As<sem::Array>()) {
          auto* offset = state.Mul(arr->Stride(), accessor->index);
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        arr->ElemType()->UnwrapRef(),
                                    });
          continue;
        }
        if (auto* vec_ty = access.type->As<sem::Vector>()) {
          auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        vec_ty->type()->UnwrapRef(),
                                    });
          continue;
        }
        if (auto* mat_ty = access.type->As<sem::Matrix>()) {
          auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
          state.AddAccess(accessor, {
                                        access.var,
                                        state.Add(access.offset, offset),
                                        mat_ty->ColumnType(),
                                    });
          continue;
        }
      }
    }

    if (auto* op = node->As<ast::UnaryOpExpression>()) {
      if (op->op == ast::UnaryOp::kAddressOf) {
        // &X
        if (auto access = state.TakeAccess(op->expr)) {
          // HLSL does not support pointers, so just take the access from the
          // reference and place it on the pointer.
          state.AddAccess(op, access);
          continue;
        }
      }
    }

    if (auto* assign = node->As<ast::AssignmentStatement>()) {
      // X = Y
      // Move the LHS access to a store.
      if (auto lhs = state.TakeAccess(assign->lhs)) {
        state.stores.emplace_back(Store{assign, lhs});
      }
    }

    if (auto* call_expr = node->As<ast::CallExpression>()) {
      auto* call = sem.Get(call_expr);
      if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) {
        if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {  // [DEPRECATED]
          // ignore(X)
          // If X is a memory access, don't transform it into a load, as it
          // may refer to a structure holding a runtime array, which cannot be
          // loaded. Instead replace X with the underlying storage / uniform
          // buffer variable.
          if (auto access = state.TakeAccess(call_expr->args[0])) {
            ctx.Replace(call_expr->args[0], [=, &ctx] {
              return ctx.CloneWithoutTransform(access.var->Declaration());
            });
          }
          continue;
        }
        if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
          // arrayLength(X)
          // Don't convert X into a load, this intrinsic actually requires the
          // real pointer.
          state.TakeAccess(call_expr->args[0]);
          continue;
        }
        if (intrinsic->IsAtomic()) {
          if (auto access = state.TakeAccess(call_expr->args[0])) {
            // atomic___(X)
            ctx.Replace(call_expr, [=, &ctx, &state] {
              auto* buf = access.var->Declaration();
              auto* offset = access.offset->Build(ctx);
              auto* buf_ty = access.var->Type()->UnwrapRef();
              auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
              Symbol func =
                  state.AtomicFunc(buf_ty, el_ty, intrinsic,
                                   access.var->As<sem::VariableUser>());

              ast::ExpressionList args{ctx.Clone(buf), offset};
              for (size_t i = 1; i < call_expr->args.size(); i++) {
                auto* arg = call_expr->args[i];
                args.emplace_back(ctx.Clone(arg));
              }
              return ctx.dst->Call(func, args);
            });
          }
        }
      }
    }
  }

  // All remaining accesses are loads. Transform these into calls to the
  // corresponding load function.
  for (auto* expr : state.expression_order) {
    auto access_it = state.accesses.find(expr);
    if (access_it == state.accesses.end()) {
      continue;
    }
    BufferAccess access = access_it->second;
    ctx.Replace(expr, [=, &ctx, &state] {
      auto* buf = access.var->Declaration();
      auto* offset = access.offset->Build(ctx);
      auto* buf_ty = access.var->Type()->UnwrapRef();
      auto* el_ty = access.type->UnwrapRef();
      Symbol func =
          state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
      return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
    });
  }

  // And replace all storage and uniform buffer assignments with stores.
  for (auto store : state.stores) {
    ctx.Replace(store.assignment, [=, &ctx, &state] {
      auto* buf = store.target.var->Declaration();
      auto* offset = store.target.offset->Build(ctx);
      auto* buf_ty = store.target.var->Type()->UnwrapRef();
      auto* el_ty = store.target.type->UnwrapRef();
      auto* value = store.assignment->rhs;
      Symbol func = state.StoreFunc(buf_ty, el_ty,
                                    store.target.var->As<sem::VariableUser>());
      auto* call = ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset,
                                 ctx.Clone(value));
      return ctx.dst->CallStmt(call);
    });
  }

  ctx.Clone();
}

}  // namespace transform
}  // namespace tint

TINT_INSTANTIATE_TYPEINFO(tint::transform::Offset);
TINT_INSTANTIATE_TYPEINFO(tint::transform::OffsetLiteral);