1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/decompose_strided_matrix.h"
16
17 #include <unordered_map>
18 #include <utility>
19 #include <vector>
20
21 #include "src/program_builder.h"
22 #include "src/sem/expression.h"
23 #include "src/sem/member_accessor_expression.h"
24 #include "src/transform/simplify_pointers.h"
25 #include "src/utils/hash.h"
26 #include "src/utils/map.h"
27
// NOTE(review): this macro is defined outside this file; presumably it emits
// the runtime type-info object for DecomposeStridedMatrix used by tint's
// casting helpers — confirm against the macro definition.
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
29
30 namespace tint {
31 namespace transform {
32 namespace {
33
34 /// MatrixInfo describes a matrix member with a custom stride
35 struct MatrixInfo {
36 /// The stride in bytes between columns of the matrix
37 uint32_t stride = 0;
38 /// The type of the matrix
39 const sem::Matrix* matrix = nullptr;
40
41 /// @returns a new ast::Array that holds an vector column for each row of the
42 /// matrix.
arraytint::transform::__anon971f5c9d0111::MatrixInfo43 const ast::Array* array(ProgramBuilder* b) const {
44 return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
45 matrix->columns(), stride);
46 }
47
48 /// Equality operator
operator ==tint::transform::__anon971f5c9d0111::MatrixInfo49 bool operator==(const MatrixInfo& info) const {
50 return stride == info.stride && matrix == info.matrix;
51 }
52 /// Hash function
53 struct Hasher {
operator ()tint::transform::__anon971f5c9d0111::MatrixInfo::Hasher54 size_t operator()(const MatrixInfo& t) const {
55 return utils::Hash(t.stride, t.matrix);
56 }
57 };
58 };
59
/// Return type of the callback function of GatherCustomStrideMatrixMembers.
/// kContinue lets the scan proceed; kStop terminates the scan immediately.
/// Scoped so that kContinue / kStop do not leak into the enclosing namespace.
enum class GatherResult { kContinue, kStop };
62
63 /// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
64 /// storage and uniform structs, which are of a matrix type, and have a custom
65 /// matrix stride attribute. For each matrix member found, `callback` is called.
66 /// `callback` is a function with the signature:
67 /// GatherResult(const sem::StructMember* member,
68 /// sem::Matrix* matrix,
69 /// uint32_t stride)
70 /// If `callback` return GatherResult::kStop, then the scanning will immediately
71 /// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
72 /// scanning will continue.
73 template <typename F>
GatherCustomStrideMatrixMembers(const Program * program,F && callback)74 void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
75 for (auto* node : program->ASTNodes().Objects()) {
76 if (auto* str = node->As<ast::Struct>()) {
77 auto* str_ty = program->Sem().Get(str);
78 if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
79 !str_ty->UsedAs(ast::StorageClass::kStorage)) {
80 continue;
81 }
82 for (auto* member : str_ty->Members()) {
83 auto* matrix = member->Type()->As<sem::Matrix>();
84 if (!matrix) {
85 continue;
86 }
87 auto* deco = ast::GetDecoration<ast::StrideDecoration>(
88 member->Declaration()->decorations);
89 if (!deco) {
90 continue;
91 }
92 uint32_t stride = deco->stride;
93 if (matrix->ColumnStride() == stride) {
94 continue;
95 }
96 if (callback(member, matrix, stride) == GatherResult::kStop) {
97 return;
98 }
99 }
100 }
101 }
102 }
103
104 } // namespace
105
// Constructor (defaulted).
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;

// Destructor (defaulted).
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
109
ShouldRun(const Program * program)110 bool DecomposeStridedMatrix::ShouldRun(const Program* program) {
111 bool should_run = false;
112 GatherCustomStrideMatrixMembers(
113 program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
114 should_run = true;
115 return GatherResult::kStop;
116 });
117 return should_run;
118 }
119
Run(CloneContext & ctx,const DataMap &,DataMap &)120 void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
121 if (!Requires<SimplifyPointers>(ctx)) {
122 return;
123 }
124
125 // Scan the program for all storage and uniform structure matrix members with
126 // a custom stride attribute. Replace these matrices with an equivalent array,
127 // and populate the `decomposed` map with the members that have been replaced.
128 std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
129 GatherCustomStrideMatrixMembers(
130 ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
131 uint32_t stride) {
132 // We've got ourselves a struct member of a matrix type with a custom
133 // stride. Replace this with an array of column vectors.
134 MatrixInfo info{stride, matrix};
135 auto* replacement = ctx.dst->Member(
136 member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
137 ctx.Replace(member->Declaration(), replacement);
138 decomposed.emplace(member->Declaration(), info);
139 return GatherResult::kContinue;
140 });
141
142 // For all expressions where a single matrix column vector was indexed, we can
143 // preserve these without calling conversion functions.
144 // Example:
145 // ssbo.mat[2] -> ssbo.mat[2]
146 ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr)
147 -> const ast::IndexAccessorExpression* {
148 if (auto* access =
149 ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
150 auto it = decomposed.find(access->Member()->Declaration());
151 if (it != decomposed.end()) {
152 auto* obj = ctx.CloneWithoutTransform(expr->object);
153 auto* idx = ctx.Clone(expr->index);
154 return ctx.dst->IndexAccessor(obj, idx);
155 }
156 }
157 return nullptr;
158 });
159
160 // For all struct member accesses to the matrix on the LHS of an assignment,
161 // we need to convert the matrix to the array before assigning to the
162 // structure.
163 // Example:
164 // ssbo.mat = mat_to_arr(m)
165 std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
166 ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt)
167 -> const ast::Statement* {
168 if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
169 auto it = decomposed.find(access->Member()->Declaration());
170 if (it == decomposed.end()) {
171 return nullptr;
172 }
173 MatrixInfo info = it->second;
174 auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
175 auto name = ctx.dst->Symbols().New(
176 "mat" + std::to_string(info.matrix->columns()) + "x" +
177 std::to_string(info.matrix->rows()) + "_stride_" +
178 std::to_string(info.stride) + "_to_arr");
179
180 auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
181 auto array = [&] { return info.array(ctx.dst); };
182
183 auto mat = ctx.dst->Sym("mat");
184 ast::ExpressionList columns(info.matrix->columns());
185 for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
186 columns[i] = ctx.dst->IndexAccessor(mat, i);
187 }
188 ctx.dst->Func(name,
189 {
190 ctx.dst->Param(mat, matrix()),
191 },
192 array(),
193 {
194 ctx.dst->Return(ctx.dst->Construct(array(), columns)),
195 });
196 return name;
197 });
198 auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
199 auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
200 return ctx.dst->Assign(lhs, rhs);
201 }
202 return nullptr;
203 });
204
205 // For all other struct member accesses, we need to convert the array to the
206 // matrix type. Example:
207 // m = arr_to_mat(ssbo.mat)
208 std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
209 ctx.ReplaceAll(
210 [&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
211 if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
212 auto it = decomposed.find(access->Member()->Declaration());
213 if (it == decomposed.end()) {
214 return nullptr;
215 }
216 MatrixInfo info = it->second;
217 auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
218 auto name = ctx.dst->Symbols().New(
219 "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
220 std::to_string(info.matrix->rows()) + "_stride_" +
221 std::to_string(info.stride));
222
223 auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
224 auto array = [&] { return info.array(ctx.dst); };
225
226 auto arr = ctx.dst->Sym("arr");
227 ast::ExpressionList columns(info.matrix->columns());
228 for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size());
229 i++) {
230 columns[i] = ctx.dst->IndexAccessor(arr, i);
231 }
232 ctx.dst->Func(
233 name,
234 {
235 ctx.dst->Param(arr, array()),
236 },
237 matrix(),
238 {
239 ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
240 });
241 return name;
242 });
243 return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
244 }
245 return nullptr;
246 });
247
248 ctx.Clone();
249 }
250
251 } // namespace transform
252 } // namespace tint
253