• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/decompose_strided_matrix.h"
16 
17 #include <unordered_map>
18 #include <utility>
19 #include <vector>
20 
21 #include "src/program_builder.h"
22 #include "src/sem/expression.h"
23 #include "src/sem/member_accessor_expression.h"
24 #include "src/transform/simplify_pointers.h"
25 #include "src/utils/hash.h"
26 #include "src/utils/map.h"
27 
28 TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
29 
30 namespace tint {
31 namespace transform {
32 namespace {
33 
34 /// MatrixInfo describes a matrix member with a custom stride
35 struct MatrixInfo {
36   /// The stride in bytes between columns of the matrix
37   uint32_t stride = 0;
38   /// The type of the matrix
39   const sem::Matrix* matrix = nullptr;
40 
41   /// @returns a new ast::Array that holds an vector column for each row of the
42   /// matrix.
arraytint::transform::__anon971f5c9d0111::MatrixInfo43   const ast::Array* array(ProgramBuilder* b) const {
44     return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
45                        matrix->columns(), stride);
46   }
47 
48   /// Equality operator
operator ==tint::transform::__anon971f5c9d0111::MatrixInfo49   bool operator==(const MatrixInfo& info) const {
50     return stride == info.stride && matrix == info.matrix;
51   }
52   /// Hash function
53   struct Hasher {
operator ()tint::transform::__anon971f5c9d0111::MatrixInfo::Hasher54     size_t operator()(const MatrixInfo& t) const {
55       return utils::Hash(t.stride, t.matrix);
56     }
57   };
58 };
59 
/// Return type of the callback function of GatherCustomStrideMatrixMembers.
/// Declared as a scoped enum so the enumerators do not leak into the
/// enclosing namespace and do not implicitly convert to integers.
enum class GatherResult {
  /// Keep scanning for further members
  kContinue,
  /// Stop scanning immediately
  kStop,
};
62 
63 /// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
64 /// storage and uniform structs, which are of a matrix type, and have a custom
65 /// matrix stride attribute. For each matrix member found, `callback` is called.
66 /// `callback` is a function with the signature:
67 ///      GatherResult(const sem::StructMember* member,
68 ///                   sem::Matrix* matrix,
69 ///                   uint32_t stride)
70 /// If `callback` return GatherResult::kStop, then the scanning will immediately
71 /// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
72 /// scanning will continue.
73 template <typename F>
GatherCustomStrideMatrixMembers(const Program * program,F && callback)74 void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
75   for (auto* node : program->ASTNodes().Objects()) {
76     if (auto* str = node->As<ast::Struct>()) {
77       auto* str_ty = program->Sem().Get(str);
78       if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
79           !str_ty->UsedAs(ast::StorageClass::kStorage)) {
80         continue;
81       }
82       for (auto* member : str_ty->Members()) {
83         auto* matrix = member->Type()->As<sem::Matrix>();
84         if (!matrix) {
85           continue;
86         }
87         auto* deco = ast::GetDecoration<ast::StrideDecoration>(
88             member->Declaration()->decorations);
89         if (!deco) {
90           continue;
91         }
92         uint32_t stride = deco->stride;
93         if (matrix->ColumnStride() == stride) {
94           continue;
95         }
96         if (callback(member, matrix, stride) == GatherResult::kStop) {
97           return;
98         }
99       }
100     }
101   }
102 }
103 
104 }  // namespace
105 
106 DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
107 
108 DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
109 
ShouldRun(const Program * program)110 bool DecomposeStridedMatrix::ShouldRun(const Program* program) {
111   bool should_run = false;
112   GatherCustomStrideMatrixMembers(
113       program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
114         should_run = true;
115         return GatherResult::kStop;
116       });
117   return should_run;
118 }
119 
Run(CloneContext & ctx,const DataMap &,DataMap &)120 void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
121   if (!Requires<SimplifyPointers>(ctx)) {
122     return;
123   }
124 
125   // Scan the program for all storage and uniform structure matrix members with
126   // a custom stride attribute. Replace these matrices with an equivalent array,
127   // and populate the `decomposed` map with the members that have been replaced.
128   std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
129   GatherCustomStrideMatrixMembers(
130       ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
131                    uint32_t stride) {
132         // We've got ourselves a struct member of a matrix type with a custom
133         // stride. Replace this with an array of column vectors.
134         MatrixInfo info{stride, matrix};
135         auto* replacement = ctx.dst->Member(
136             member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
137         ctx.Replace(member->Declaration(), replacement);
138         decomposed.emplace(member->Declaration(), info);
139         return GatherResult::kContinue;
140       });
141 
142   // For all expressions where a single matrix column vector was indexed, we can
143   // preserve these without calling conversion functions.
144   // Example:
145   //   ssbo.mat[2] -> ssbo.mat[2]
146   ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr)
147                      -> const ast::IndexAccessorExpression* {
148     if (auto* access =
149             ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
150       auto it = decomposed.find(access->Member()->Declaration());
151       if (it != decomposed.end()) {
152         auto* obj = ctx.CloneWithoutTransform(expr->object);
153         auto* idx = ctx.Clone(expr->index);
154         return ctx.dst->IndexAccessor(obj, idx);
155       }
156     }
157     return nullptr;
158   });
159 
160   // For all struct member accesses to the matrix on the LHS of an assignment,
161   // we need to convert the matrix to the array before assigning to the
162   // structure.
163   // Example:
164   //   ssbo.mat = mat_to_arr(m)
165   std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
166   ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt)
167                      -> const ast::Statement* {
168     if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
169       auto it = decomposed.find(access->Member()->Declaration());
170       if (it == decomposed.end()) {
171         return nullptr;
172       }
173       MatrixInfo info = it->second;
174       auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
175         auto name = ctx.dst->Symbols().New(
176             "mat" + std::to_string(info.matrix->columns()) + "x" +
177             std::to_string(info.matrix->rows()) + "_stride_" +
178             std::to_string(info.stride) + "_to_arr");
179 
180         auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
181         auto array = [&] { return info.array(ctx.dst); };
182 
183         auto mat = ctx.dst->Sym("mat");
184         ast::ExpressionList columns(info.matrix->columns());
185         for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
186           columns[i] = ctx.dst->IndexAccessor(mat, i);
187         }
188         ctx.dst->Func(name,
189                       {
190                           ctx.dst->Param(mat, matrix()),
191                       },
192                       array(),
193                       {
194                           ctx.dst->Return(ctx.dst->Construct(array(), columns)),
195                       });
196         return name;
197       });
198       auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
199       auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
200       return ctx.dst->Assign(lhs, rhs);
201     }
202     return nullptr;
203   });
204 
205   // For all other struct member accesses, we need to convert the array to the
206   // matrix type. Example:
207   //   m = arr_to_mat(ssbo.mat)
208   std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
209   ctx.ReplaceAll(
210       [&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
211         if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
212           auto it = decomposed.find(access->Member()->Declaration());
213           if (it == decomposed.end()) {
214             return nullptr;
215           }
216           MatrixInfo info = it->second;
217           auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
218             auto name = ctx.dst->Symbols().New(
219                 "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
220                 std::to_string(info.matrix->rows()) + "_stride_" +
221                 std::to_string(info.stride));
222 
223             auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
224             auto array = [&] { return info.array(ctx.dst); };
225 
226             auto arr = ctx.dst->Sym("arr");
227             ast::ExpressionList columns(info.matrix->columns());
228             for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size());
229                  i++) {
230               columns[i] = ctx.dst->IndexAccessor(arr, i);
231             }
232             ctx.dst->Func(
233                 name,
234                 {
235                     ctx.dst->Param(arr, array()),
236                 },
237                 matrix(),
238                 {
239                     ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
240                 });
241             return name;
242           });
243           return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
244         }
245         return nullptr;
246       });
247 
248   ctx.Clone();
249 }
250 
251 }  // namespace transform
252 }  // namespace tint
253