1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12
13 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15
16 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17
18 namespace vkcompute {
19
check_pool2d_args(const api::vTensor & in,const api::vTensor & out)20 void check_pool2d_args(const api::vTensor& in, const api::vTensor& out) {
21 VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
22 VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
23 }
24
resize_pool2d_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)25 void resize_pool2d_node(
26 ComputeGraph* graph,
27 const std::vector<ArgGroup>& args,
28 const std::vector<ValueRef>& extra_args) {
29 bool is_max_pool2d = extra_args[3] != kDummyValueRef;
30
31 vTensorPtr out = graph->get_tensor(args[0].refs[0]);
32 vTensorPtr self = graph->get_tensor(args[1].refs[0]);
33
34 size_t ndim = self->sizes().size();
35 std::vector<int64_t> new_out_sizes(ndim);
36
37 // Batch, Channel
38 if (ndim == 4) {
39 new_out_sizes.at(ndim - 4) = self->sizes().at(ndim - 4);
40 }
41 new_out_sizes.at(ndim - 3) = self->sizes().at(ndim - 3);
42
43 // Height, Width
44 const auto& new_out_sizes_hw = calc_out_sizes_hw(
45 *graph,
46 self->sizes(),
47 extra_args[0],
48 /*kernel_size_only = */ true,
49 {extra_args[1], extra_args[2], extra_args[3], extra_args[4]});
50 new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
51 new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);
52
53 out->virtual_resize(new_out_sizes);
54
55 if (is_max_pool2d) {
56 vTensorPtr indices = graph->get_tensor(args[0].refs[1]);
57 indices->virtual_resize(new_out_sizes);
58 }
59 }
60
61 //
62 // max_pool2d
63 //
64
add_max_pool2d_node(ComputeGraph & graph,const ValueRef in,const ValueRef kernel_size,const ValueRef stride,const ValueRef padding,const ValueRef dilation,const ValueRef ceil_mode,const ValueRef out)65 void add_max_pool2d_node(
66 ComputeGraph& graph,
67 const ValueRef in,
68 const ValueRef kernel_size,
69 const ValueRef stride,
70 const ValueRef padding,
71 const ValueRef dilation,
72 const ValueRef ceil_mode,
73 const ValueRef out) {
74 vTensorPtr t_in = graph.get_tensor(in);
75
76 const auto out_val = graph.get_value_list(out);
77 vTensorPtr t_out = graph.get_tensor(out_val->at(0));
78
79 check_pool2d_args(*t_in, *t_out);
80
81 utils::uvec3 global_size = t_out->logical_limits();
82 utils::uvec3 local_size = adaptive_work_group_size(global_size);
83
84 std::string kernel_name("max_pool2d");
85 add_dtype_suffix(kernel_name, *t_out);
86
87 Kernel2dParams kernel_params = create_kernel2d_params(
88 graph,
89 kernel_size,
90 /*kernel_size_only = */ true,
91 stride,
92 padding,
93 dilation);
94
95 graph.execute_nodes().emplace_back(new DispatchNode(
96 graph,
97 VK_KERNEL_FROM_STR(kernel_name),
98 global_size,
99 local_size,
100 // Inputs and Outputs
101 {{{out_val->at(0), out_val->at(1)}, vkapi::MemoryAccessType::WRITE},
102 {in, vkapi::MemoryAccessType::READ}},
103 // Shader params buffers
104 {
105 t_out->logical_limits_ubo(),
106 t_in->sizes_ubo(),
107 graph.create_params_buffer(kernel_params),
108 },
109 // Specialization Constants
110 {},
111 // Resizing Logic
112 resize_pool2d_node,
113 {kernel_size, stride, padding, dilation, ceil_mode}));
114 }
115
max_pool2d(ComputeGraph & graph,const std::vector<ValueRef> & args)116 void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
117 return add_max_pool2d_node(
118 graph, args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
119 }
120
121 //
122 // avg_pool2d
123 //
124
/*
 * Divisor configuration for avg_pool2d. In this file it is handed to
 * graph.create_params_buffer() and bound as a shader params buffer.
 *
 * Both fields are 32-bit so that every byte of the struct is defined. A
 * trailing C++ `bool` is one byte followed by padding with indeterminate
 * contents, whereas a scalar in a GLSL uniform block occupies four bytes, so
 * a shader could read garbage in the upper bytes of the flag.
 * NOTE(review): the shader-side block declaration is not visible here;
 * confirm it declares this member as a 4-byte bool/int.
 */
struct DivisorParams final {
  // Divisor override from the operator; create_divisor_params stores 0 when
  // the override value is not an int.
  int32_t divisor_override;
  // Boolean flag stored as a full word: 0 = false, 1 = true.
  int32_t count_include_pad;
};
129
create_divisor_params(ComputeGraph & graph,const ValueRef divisor_override,const ValueRef count_include_pad)130 DivisorParams create_divisor_params(
131 ComputeGraph& graph,
132 const ValueRef divisor_override,
133 const ValueRef count_include_pad) {
134 return {
135 graph.val_is_int(divisor_override)
136 ? static_cast<int32_t>(graph.get_int(divisor_override))
137 : 0,
138 graph.get_bool(count_include_pad)};
139 }
140
add_avg_pool2d_node(ComputeGraph & graph,const ValueRef in,const ValueRef kernel_size,const ValueRef stride,const ValueRef padding,const ValueRef ceil_mode,const ValueRef count_include_pad,const ValueRef divisor_override,const ValueRef out)141 void add_avg_pool2d_node(
142 ComputeGraph& graph,
143 const ValueRef in,
144 const ValueRef kernel_size,
145 const ValueRef stride,
146 const ValueRef padding,
147 const ValueRef ceil_mode,
148 const ValueRef count_include_pad,
149 const ValueRef divisor_override,
150 const ValueRef out) {
151 vTensorPtr t_in = graph.get_tensor(in);
152 vTensorPtr t_out = graph.get_tensor(out);
153
154 check_pool2d_args(*t_in, *t_out);
155
156 utils::uvec3 global_size = t_out->logical_limits();
157 utils::uvec3 local_size = adaptive_work_group_size(global_size);
158
159 std::string kernel_name("avg_pool2d");
160 add_dtype_suffix(kernel_name, *t_out);
161
162 Kernel2dParams kernel_params =
163 create_kernel2d_params(graph, kernel_size, stride, padding);
164
165 DivisorParams divisor_params =
166 create_divisor_params(graph, divisor_override, count_include_pad);
167
168 graph.execute_nodes().emplace_back(new DispatchNode(
169 graph,
170 VK_KERNEL_FROM_STR(kernel_name),
171 global_size,
172 local_size,
173 // Inputs and Outputs
174 {{out, vkapi::MemoryAccessType::WRITE},
175 {in, vkapi::MemoryAccessType::READ}},
176 // Shader params buffers
177 {t_out->logical_limits_ubo(),
178 t_in->sizes_ubo(),
179 graph.create_params_buffer(kernel_params),
180 graph.create_params_buffer(divisor_params)},
181 // Specialization Constants
182 {},
183 // Resizing Logic
184 resize_pool2d_node,
185 {kernel_size,
186 stride,
187 padding,
188 /*dilation= */ kDummyValueRef,
189 ceil_mode}));
190 }
191
avg_pool2d(ComputeGraph & graph,const std::vector<ValueRef> & args)192 void avg_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
193 return add_avg_pool2d_node(
194 graph,
195 args[0],
196 args[1],
197 args[2],
198 args[3],
199 args[4],
200 args[5],
201 args[6],
202 args[7]);
203 }
204
205 REGISTER_OPERATORS {
206 VK_REGISTER_OP(aten.avg_pool2d.default, avg_pool2d);
207 VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
208 }
209
210 } // namespace vkcompute
211