/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

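// Validate the addmm argument shapes: mat1 and mat2 must both be 2D or 3D
// with matching inner dimensions, mat1 must share the output's packed
// dimension, and self must broadcast against the last two output dims.
// alpha and beta are accepted for signature parity but are not checked.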
void check_addmm_args(
    ComputeGraph& graph,
    const ValueRef self,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out) {
  (void)alpha;
  (void)beta;

  std::vector<int64_t> self_sizes = graph.sizes_of(self);
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
  std::vector<int64_t> mat2_sizes = graph.sizes_of(mat2_data);

  VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
  VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());

  VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

  VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes));

  if (utils::val_at(-1, self_sizes) != 1) {
    VK_CHECK_COND(
        utils::val_at(-1, self_sizes) == utils::val_at(-1, mat2_sizes));
  }
  if (utils::val_at(-2, self_sizes) != 1) {
    VK_CHECK_COND(
        utils::val_at(-2, self_sizes) == utils::val_at(-2, mat1_sizes));
  }
}

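// Resize callback for the addmm dispatch nodes: recompute the output sizes
// from the current mat1/mat2 sizes (accounting for a transposed mat2) and
// virtually resize the output tensor.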
void resize_addmm_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);
  vTensorPtr self = graph->get_tensor(args[1].refs[2]);

  bool mat2_is_transposed = graph->get_bool(extra_args[0]);

  const int out_rows = utils::val_at(-2, mat1->sizes());
  const int out_cols = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())
                                          : utils::val_at(-1, mat2->sizes());

  std::vector<int64_t> new_out_sizes(3);
  if (mat1->sizes().size() == 2) {
    new_out_sizes.resize(2);
    new_out_sizes.at(0) = out_rows;
    new_out_sizes.at(1) = out_cols;
  } else {
    new_out_sizes.at(0) = mat1->sizes().at(0);
    new_out_sizes.at(1) = out_rows;
    new_out_sizes.at(2) = out_cols;
  }

  out->virtual_resize(new_out_sizes);
}

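// Scaling factors forwarded to the shader via a params buffer.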
struct Params final {
  float alpha;
  float beta;
};

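// Naive addmm/linear implementation: self and mat2 are prepacked and the
// shader is dispatched with one invocation per element of the output's
// logical limits.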
void add_addmm_naive_node(
    ComputeGraph& graph,
    const ValueRef self_data,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out,
    const Params& params,
    const ValueRef mat2_is_transposed) {
  utils::StorageType stype = graph.storage_type_of(out);
  ValueRef self = prepack_standard(
      graph, self_data, stype, utils::kWidthPacked, /*passthrough = */ true);
  ValueRef mat2 = prepack_standard(
      graph, mat2_data, stype, utils::kHeightPacked, /*passthrough = */ true);

  std::string kernel_name =
      graph.get_bool(mat2_is_transposed) ? "linear_naive" : "addmm_naive";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      graph.create_local_wg_size(global_wg_size),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2, self}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.logical_limits_ubo(out),
          graph.sizes_ubo(mat1),
          graph.sizes_ubo(mat2),
          graph.sizes_ubo(self),
          graph.create_params_buffer(params),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1),
       graph.hashed_layout_of(mat2),
       graph.hashed_layout_of(self)},
      // Resizing Logic
      resize_addmm_node,
      {mat2_is_transposed}));
}

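// Tiled addmm/linear implementation: mat1 is repacked to a width packed
// layout and mat2 to the layout expected by the shader via view_copy nodes,
// and each invocation computes a small output tile (see the workgroup size
// calculation below).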
void add_addmm_optimized_node(
    ComputeGraph& graph,
    const ValueRef self_data,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out,
    const Params& params,
    const ValueRef mat2_is_transposed) {
  utils::StorageType stype = graph.storage_type_of(out);
  ValueRef self = prepack_standard(
      graph, self_data, stype, utils::kChannelsPacked, /*passthrough=*/true);
  ValueRef mat2 = prepack_standard(
      graph, mat2_data, stype, utils::kHeightPacked, /*passthrough=*/true);

  // Ensure mat1 is width packed
  ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});

  const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);

  // Ensure mat2 is height packed (or width packed when it is transposed)
  ValueRef mat2_packed = mat2;
  const utils::GPUMemoryLayout mat2_layout =
      mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
  if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
    viewFn(graph, {mat2, graph.add_none(), mat2_packed});
  }

  std::string kernel_name = graph.get_bool(mat2_is_transposed)
      ? "linear_optimized"
      : "addmm_optimized";

  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
  int mat1_dims = mat1_sizes.size();
  if (mat1_dims == 3) {
    kernel_name = "batch_" + kernel_name;
  }
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    kernel_name += "_tile_row_2";
  } else {
    kernel_name += "_tile_row_4";
  }

  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  utils::uvec3 global_size = graph.logical_limits_of(out);

  // Each thread computes a W=4 x H=(2 or 4) x C=(1 or 4) output tile.
  // Therefore, the total number of threads is W/4 x H/(2 or 4) x C/1. Since
  // the out tensor is channels packed, C does not need to be divided by 4. The
  // "identity" of each thread is the (x, y, z) coordinate of the output tile
  // it is computing, and this identity can be used to compute the tensor index
  // of the top left element in the tile, which will be
  // [W=x*4, H=y*(2 or 4), C=z*(1 or 4), N=0]
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    // Use `logical_extents` instead of `image_extents` because the workgroup
    // axes need to correspond to tensor dimensions.
    global_size = utils::divup_vec(global_size, {4, 2, 1});
  } else {
    global_size = utils::divup_vec(global_size, {4, 4, 1});
  }
  utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, mat2_packed, self}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.sizes_ubo(mat1_W_packed),
          graph.sizes_ubo(mat2_packed),
          graph.sizes_ubo(self),
          graph.create_params_buffer(params),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1_W_packed),
       graph.hashed_layout_of(mat2_packed),
       graph.hashed_layout_of(self)},
      // Resizing Logic
      resize_addmm_node,
      {mat2_is_transposed}));
}

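// Dispatches to the tiled implementation when mat1 is channels packed and to
// the naive implementation when mat1 is width packed; other packings are
// rejected.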
void add_addmm_node(
    ComputeGraph& graph,
    const ValueRef self,
    const ValueRef mat1,
    const ValueRef mat2,
    const ValueRef beta,
    const ValueRef alpha,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  float alpha_val = 1.0f;
  float beta_val = 1.0f;

  if (alpha != kDummyValueRef) {
    alpha_val = graph.extract_scalar<float>(alpha);
  }
  if (beta != kDummyValueRef) {
    beta_val = graph.extract_scalar<float>(beta);
  }

  Params params = {alpha_val, beta_val};
  if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
    add_addmm_optimized_node(
        graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
  } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) {
    add_addmm_naive_node(
        graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
  } else {
    VK_THROW("Input should be channel packed or width packed.");
  }
}

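// Handler for aten.addmm.default. args are (self, mat1, mat2, beta, alpha, out).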
void addmm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  check_addmm_args(graph, args[0], args[1], args[2], args[3], args[4], args[5]);
  ValueRef mat2_is_transposed = graph.add_scalar(false);
  return add_addmm_node(
      graph,
      args[0],
      args[1],
      args[2],
      args[3],
      args[4],
      args[5],
      mat2_is_transposed);
}

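// Handler for aten.linear.default. The weight is prepacked width packed and
// treated as a transposed mat2; when there is no bias the op lowers to a
// plain matmul.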
void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  ValueRef input = args.at(0);
  ValueRef weight_data = args.at(1);
  ValueRef bias = args.at(2);
  ValueRef out = args.at(3);
  ValueRef weight = prepack_standard(
      graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
  ValueRef mat2_is_transposed = graph.add_scalar(true);

  if (graph.val_is_none(bias)) {
    return add_matmul_node(graph, input, weight, out, mat2_is_transposed);
  } else {
    // Buffer implementation does not yet support biases
    VK_CHECK_COND(!graph.is_buffer_storage(out));
    return add_addmm_node(
        graph,
        bias,
        input,
        weight,
        kDummyValueRef,
        kDummyValueRef,
        out,
        mat2_is_transposed);
  }
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.addmm.default, addmm);
  VK_REGISTER_OP(aten.linear.default, linear);
}

} // namespace vkcompute