/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

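// Both pooling ops assume channels-packed tensors; fail early if either the
// input or the output is packed along a different dim.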
void check_pool2d_args(const api::vTensor& in, const api::vTensor& out) {
  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
}

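// Shared resize callback for max_pool2d and avg_pool2d. extra_args holds
// {kernel_size, stride, padding, dilation, ceil_mode}; avg_pool2d passes
// kDummyValueRef in the dilation slot, which is how the two ops are told
// apart here. For max_pool2d, the indices output is resized alongside the
// pooled values.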
void resize_pool2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  bool is_max_pool2d = extra_args[3] != kDummyValueRef;

  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr self = graph->get_tensor(args[1].refs[0]);

  size_t ndim = self->sizes().size();
  std::vector<int64_t> new_out_sizes(ndim);

  // Batch, Channel
  if (ndim == 4) {
    new_out_sizes.at(ndim - 4) = self->sizes().at(ndim - 4);
  }
  new_out_sizes.at(ndim - 3) = self->sizes().at(ndim - 3);

  // Height, Width
  const auto& new_out_sizes_hw = calc_out_sizes_hw(
      *graph,
      self->sizes(),
      extra_args[0],
      /*kernel_size_only = */ true,
      {extra_args[1], extra_args[2], extra_args[3], extra_args[4]});
  new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
  new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);

  out->virtual_resize(new_out_sizes);

  if (is_max_pool2d) {
    vTensorPtr indices = graph->get_tensor(args[0].refs[1]);
    indices->virtual_resize(new_out_sizes);
  }
}

//
// max_pool2d
//

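// Builds the max_pool2d dispatch. The out ValueRef is a value list holding
// both the pooled values and the indices tensor, so both are bound as WRITE
// arguments of the same shader invocation.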
void add_max_pool2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef kernel_size,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef ceil_mode,
    const ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);

  const auto out_val = graph.get_value_list(out);
  vTensorPtr t_out = graph.get_tensor(out_val->at(0));

  check_pool2d_args(*t_in, *t_out);

  utils::uvec3 global_size = t_out->logical_limits();
  utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("max_pool2d");
  add_dtype_suffix(kernel_name, *t_out);

  Kernel2dParams kernel_params = create_kernel2d_params(
      graph,
      kernel_size,
      /*kernel_size_only = */ true,
      stride,
      padding,
      dilation);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{{out_val->at(0), out_val->at(1)}, vkapi::MemoryAccessType::WRITE},
       {in, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          t_out->logical_limits_ubo(),
          t_in->sizes_ubo(),
          graph.create_params_buffer(kernel_params),
      },
      // Specialization Constants
      {},
      // Resizing Logic
      resize_pool2d_node,
      {kernel_size, stride, padding, dilation, ceil_mode}));
}

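// Thin wrapper that unpacks the argument list in the order expected by
// add_max_pool2d_node: (in, kernel_size, stride, padding, dilation,
// ceil_mode, out).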
void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_max_pool2d_node(
      graph, args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
}

//
// avg_pool2d
//

struct DivisorParams final {
  int32_t divisor_override;
  bool count_include_pad;
};

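// Packs the divisor-related options into a single struct for the shader. A
// divisor_override of 0 signals that no override was provided; in that case
// the shader is presumably left to divide by the (possibly pad-adjusted)
// kernel area.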
DivisorParams create_divisor_params(
    ComputeGraph& graph,
    const ValueRef divisor_override,
    const ValueRef count_include_pad) {
  return {
      graph.val_is_int(divisor_override)
          ? static_cast<int32_t>(graph.get_int(divisor_override))
          : 0,
      graph.get_bool(count_include_pad)};
}

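// Builds the avg_pool2d dispatch. Unlike max_pool2d there is a single output
// tensor and no dilation, so kDummyValueRef is passed in the dilation slot of
// the resize args to keep resize_pool2d_node's argument layout consistent.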
void add_avg_pool2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef kernel_size,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef ceil_mode,
    const ValueRef count_include_pad,
    const ValueRef divisor_override,
    const ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  check_pool2d_args(*t_in, *t_out);

  utils::uvec3 global_size = t_out->logical_limits();
  utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("avg_pool2d");
  add_dtype_suffix(kernel_name, *t_out);

  Kernel2dParams kernel_params =
      create_kernel2d_params(graph, kernel_size, stride, padding);

  DivisorParams divisor_params =
      create_divisor_params(graph, divisor_override, count_include_pad);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {in, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {t_out->logical_limits_ubo(),
       t_in->sizes_ubo(),
       graph.create_params_buffer(kernel_params),
       graph.create_params_buffer(divisor_params)},
      // Specialization Constants
      {},
      // Resizing Logic
      resize_pool2d_node,
      {kernel_size,
       stride,
       padding,
       /*dilation= */ kDummyValueRef,
       ceil_mode}));
}

void avg_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_avg_pool2d_node(
      graph,
      args[0],
      args[1],
      args[2],
      args[3],
      args[4],
      args[5],
      args[6],
      args[7]);
}

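// Register the wrapper functions against their ATen operator names so the
// Vulkan backend can look them up when building the compute graph.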
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.avg_pool2d.default, avg_pool2d);
  VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
}

} // namespace vkcompute