/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

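// Validate the inputs to addmm/linear before any nodes are added to the
// graph: mat1 must be a 2D or 3D tensor with the same rank as mat2, mat1 and
// the output must share the same packed dimension, and the inner dimensions
// of mat1 and mat2 must match. The last dimension of self (the bias) must be
// 1 or match the last dimension of mat2, and its second-to-last dimension
// must be 1 or match the second-to-last dimension of mat1.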
void check_addmm_args(
    ComputeGraph& graph,
    const ValueRef self,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out) {
  (void)alpha;
  (void)beta;

  std::vector<int64_t> self_sizes = graph.sizes_of(self);
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
  std::vector<int64_t> mat2_sizes = graph.sizes_of(mat2_data);

  VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
  VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());

  VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

  VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes));

  if (utils::val_at(-1, self_sizes) != 1) {
    VK_CHECK_COND(
        utils::val_at(-1, self_sizes) == utils::val_at(-1, mat2_sizes));
  }
  if (utils::val_at(-2, self_sizes) != 1) {
    VK_CHECK_COND(
        utils::val_at(-2, self_sizes) == utils::val_at(-2, mat1_sizes));
  }
}

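// Resize callback invoked when input shapes change between executions. The
// output is resized to [B,] M x N, where M is the second-to-last dimension of
// mat1 and N is the last (or, when mat2 is transposed, the second-to-last)
// dimension of mat2; the batch dimension, if present, is taken from mat1.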
void resize_addmm_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);
  vTensorPtr self = graph->get_tensor(args[1].refs[2]);

  bool mat2_is_transposed = graph->get_bool(extra_args[0]);

  const int out_cols = utils::val_at(-2, mat1->sizes());
  const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())
                                          : utils::val_at(-1, mat2->sizes());

  std::vector<int64_t> new_out_sizes(3);
  if (mat1->sizes().size() == 2) {
    new_out_sizes.resize(2);
    new_out_sizes.at(0) = out_cols;
    new_out_sizes.at(1) = out_rows;
  } else {
    new_out_sizes.at(0) = mat1->sizes().at(0);
    new_out_sizes.at(1) = out_cols;
    new_out_sizes.at(2) = out_rows;
  }

  out->virtual_resize(new_out_sizes);
}

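// Scaling factors for aten.addmm (out = beta * self + alpha * mat1 @ mat2),
// passed to the compute shader through a params buffer.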
struct Params final {
  float alpha;
  float beta;
};

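// Dispatch the naive addmm/linear shader. The self (bias) tensor is prepacked
// width packed and mat2 height packed; the kernel name is assembled from the
// transpose flag plus storage type and dtype suffixes, and the global
// workgroup size is taken directly from the output's logical limits.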
void add_addmm_naive_node(
    ComputeGraph& graph,
    const ValueRef self_data,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out,
    const Params& params,
    const ValueRef mat2_is_transposed) {
  utils::StorageType stype = graph.storage_type_of(out);
  ValueRef self = prepack_standard(
      graph, self_data, stype, utils::kWidthPacked, /*passthrough = */ true);
  ValueRef mat2 = prepack_standard(
      graph, mat2_data, stype, utils::kHeightPacked, /*passthrough = */ true);

  std::string kernel_name =
      graph.get_bool(mat2_is_transposed) ? "linear_naive" : "addmm_naive";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      graph.create_local_wg_size(global_wg_size),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2, self}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.logical_limits_ubo(out),
          graph.sizes_ubo(mat1),
          graph.sizes_ubo(mat2),
          graph.sizes_ubo(self),
          graph.create_params_buffer(params),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1),
       graph.hashed_layout_of(mat2),
       graph.hashed_layout_of(self)},
      // Resizing Logic
      resize_addmm_node,
      {mat2_is_transposed}));
}

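// Dispatch the tiled addmm/linear shader. mat1 is re-laid-out to a width
// packed layout via a view_copy node, and mat2 is re-laid-out to height
// packed (width packed when transposed) only if its estimated layout differs.
// The kernel variant (batch_ prefix, _tile_row_2 vs _tile_row_4 suffix) is
// chosen from mat1's rank and its second-to-last dimension.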
void add_addmm_optimized_node(
    ComputeGraph& graph,
    const ValueRef self_data,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out,
    const Params& params,
    const ValueRef mat2_is_transposed) {
  utils::StorageType stype = graph.storage_type_of(out);
  ValueRef self = prepack_standard(
      graph, self_data, stype, utils::kChannelsPacked, /*passthrough=*/true);
  ValueRef mat2 = prepack_standard(
      graph, mat2_data, stype, utils::kHeightPacked, /*passthrough=*/true);

  // Ensure mat1 is width packed
  ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});

  const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);

  // Ensure mat2 is height packed
  ValueRef mat2_packed = mat2;
  const utils::GPUMemoryLayout mat2_layout =
      mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
  if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
    viewFn(graph, {mat2, graph.add_none(), mat2_packed});
  }

  std::string kernel_name = graph.get_bool(mat2_is_transposed)
      ? "linear_optimized"
      : "addmm_optimized";

  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
  int mat1_dims = mat1_sizes.size();
  if (mat1_dims == 3) {
    kernel_name = "batch_" + kernel_name;
  }
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    kernel_name += "_tile_row_2";
  } else {
    kernel_name += "_tile_row_4";
  }

  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  utils::uvec3 global_size = graph.logical_limits_of(out);

  // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the
  // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is
  // channels packed, C does not need to be divided by 4. The "identity" of each
  // thread is the (x, y, z) coordinate of the output tile it is computing, and
  // this identity can be used to compute the tensor index of the top left
  // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0]
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    // Use `logical_extents` instead of `image_extents` because the workgroup
    // axes need to correspond to tensor dimensions.
    global_size = utils::divup_vec(global_size, {4, 2, 1});
  } else {
    global_size = utils::divup_vec(global_size, {4, 4, 1});
  }
  utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, mat2_packed, self}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.sizes_ubo(mat1_W_packed),
          graph.sizes_ubo(mat2_packed),
          graph.sizes_ubo(self),
          graph.create_params_buffer(params),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1_W_packed),
       graph.hashed_layout_of(mat2_packed),
       graph.hashed_layout_of(self)},
      // Resizing Logic
      resize_addmm_node,
      {mat2_is_transposed}));
}

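// Shared entry point for addmm and linear: extracts alpha and beta (both
// default to 1.0 when not provided), then selects the optimized tiled shader
// when mat1 is channels packed or the naive shader when it is width packed.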
void add_addmm_node(
    ComputeGraph& graph,
    const ValueRef self,
    const ValueRef mat1,
    const ValueRef mat2,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  float alpha_val = 1.0f;
  float beta_val = 1.0f;

  if (alpha != kDummyValueRef) {
    alpha_val = graph.extract_scalar<float>(alpha);
  }
  if (beta != kDummyValueRef) {
    beta_val = graph.extract_scalar<float>(beta);
  }

  Params params = {alpha_val, beta_val};
  if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
    add_addmm_optimized_node(
        graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
  } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) {
    add_addmm_naive_node(
        graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
  } else {
    VK_THROW("Input should be channel packed or width packed.");
  }
}

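// aten.addmm.default: args are (self, mat1, mat2, beta, alpha, out). mat2 is
// used as-is, so the transpose flag is set to false.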
void addmm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  check_addmm_args(graph, args[0], args[1], args[2], args[3], args[4], args[5]);
  ValueRef mat2_is_transposed = graph.add_scalar(false);
  return add_addmm_node(
      graph,
      args[0],
      args[1],
      args[2],
      args[3],
      args[4],
      args[5],
      mat2_is_transposed);
}

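// aten.linear.default: args are (input, weight, bias, out). The weight is
// prepacked width packed and treated as a transposed mat2; when there is no
// bias the op lowers to a plain matmul instead of addmm.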
void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  ValueRef input = args.at(0);
  ValueRef weight_data = args.at(1);
  ValueRef bias = args.at(2);
  ValueRef out = args.at(3);
  ValueRef weight = prepack_standard(
      graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
  ValueRef mat2_is_transposed = graph.add_scalar(true);

  if (graph.val_is_none(bias)) {
    return add_matmul_node(graph, input, weight, out, mat2_is_transposed);
  } else {
    // Buffer implementation does not yet support biases
    VK_CHECK_COND(!graph.is_buffer_storage(out));
    return add_addmm_node(
        graph,
        bias,
        input,
        weight,
        kDummyValueRef,
        kDummyValueRef,
        out,
        mat2_is_transposed);
  }
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.addmm.default, addmm);
  VK_REGISTER_OP(aten.linear.default, linear);
}

} // namespace vkcompute