/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Slice.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

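// Resolve a Python-style slice index: INT64_MAX (the "unspecified" sentinel)
// maps to `default_value`, values equal to `default_value` pass through, and
// anything else is wrapped into [0, max) via normalize() (so, illustratively,
// an index of -1 with max = 10 would resolve to 9).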
inline int64_t normalize_idx(
    const int64_t index,
    const int64_t max,
    const int64_t default_value) {
  // INT64_MAX is passed when the value is unspecified
  if (index == INT64_MAX) {
    return default_value;
  }
  if (index == default_value) {
    return index;
  }
  return normalize(index, max);
}

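// Adds a compute shader dispatch that copies the sliced region of `in` into
// the separately allocated `out` tensor. Both tensors are expected to be
// channels-packed; slicing along the channel dim uses a dedicated shader,
// while batch/height/width slices share a single shader.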
void add_slice_tensor_copy_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref,
    ValueRef step_ref,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim));

  // Need to normalize the dim
  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);

  VK_CHECK_COND(
      -t_in->dim() <= dim && dim < t_in->dim(),
      "dim must be in range of [-self.dim(), self.dim()), but current dim's value is ",
      dim,
      " and self.dim() = ",
      t_in->dim());

  dim = normalize(dim, t_in->dim());

  DimIndex dim_index = normalize_to_dim_index(*t_in, dim);

  std::optional<int64_t> opt_start =
      graph.extract_optional_scalar<int64_t>(opt_start_ref);
  std::optional<int64_t> opt_end =
      graph.extract_optional_scalar<int64_t>(opt_end_ref);
  int64_t step = graph.extract_scalar<int64_t>(step_ref);

  const auto in_sizes = t_in->sizes();
  const auto out_sizes = t_out->sizes();

  int64_t start = opt_start.value_or(0);
  int64_t end = opt_end.value_or(in_sizes[dim]);

  start = normalize_idx(start, in_sizes[dim], 0);
  end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);

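  // Two dispatch paths: slicing along channels needs its own shader because
  // channels are packed four-to-a-texel, while batch/height/width slices map
  // directly onto texture coordinates and share one shader.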
  if (dim_index == kChannel4D) {
    // slice by channel
    std::string kernel_name = "slice_channel";
    kernel_name.reserve(kShaderNameReserve);
    add_dtype_suffix(kernel_name, *t_out);

    const struct Block final {
      int offset;
      int step;
    } params{
        static_cast<int32_t>(start),
        static_cast<int32_t>(step),
    };

    graph.execute_nodes().emplace_back(new DispatchNode(
        graph,
        VK_KERNEL_FROM_STR(kernel_name),
        graph.create_global_wg_size(out),
        graph.create_local_wg_size(out),
        {{out, vkapi::MemoryAccessType::WRITE},
         {in, vkapi::MemoryAccessType::READ}},
        {t_out->sizes_ubo(),
         t_in->sizes_ubo(),
         graph.create_params_buffer(params)}));

  } else {
    // GPU coordinates are in x, y, z
    int64_t gpu_dim = -1;
    int64_t stride = 1;
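    // The output extent along the sliced dim should equal
    // ceil((end - start) / step) == 1 + (end - start - 1) / step.
    // Illustrative example: start = 0, end = 7, step = 2 on a size-10 dim
    // selects indices {0, 2, 4, 6}, i.e. 1 + (7 - 0 - 1) / 2 = 4 elements.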
    if (dim_index == kWidth4D) {
      gpu_dim = 0; // width: x dimension in gpu
      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
    } else if (dim_index == kHeight4D) {
      gpu_dim = 1; // height: y dimension
      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
    } else if (dim_index == kBatch4D) {
      gpu_dim = 2; // batch: z dimension

      // Due to channel packing, each batch value spans `stride` planes
      int64_t n_channels = dim_at(in_sizes, kChannel4D);
      stride = utils::div_up_4(n_channels);
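      // e.g. (illustrative) n_channels = 10 -> stride = div_up_4(10) = 3,
      // since 10 channels occupy 3 texels of 4 packed channels each.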
    } else {
      VK_THROW("Unexpected ncwh_dim!");
    }

    std::string kernel_name = "slice_batch_height_width";
    kernel_name.reserve(kShaderNameReserve);
    add_dtype_suffix(kernel_name, *t_out);

    utils::uvec3 global_size = t_out->logical_limits();
    utils::uvec3 local_size = adaptive_work_group_size(global_size);

    const struct Block final {
      int dim;
      int offset;
      int step;
      int stride;
    } params{
        static_cast<int32_t>(gpu_dim),
        static_cast<int32_t>(start),
        static_cast<int32_t>(step),
        static_cast<int32_t>(stride),
    };

    graph.execute_nodes().emplace_back(new DispatchNode(
        graph,
        VK_KERNEL_FROM_STR(kernel_name),
        global_size,
        local_size,
        {{out, vkapi::MemoryAccessType::WRITE},
         {in, vkapi::MemoryAccessType::READ}},
        {t_out->sizes_ubo(), graph.create_params_buffer(params)}));
  }
}

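// Computes the output sizes of a slice along `dim_ref`: identical to the input
// sizes except at the sliced dim, which becomes `end - start` (step is not
// accounted for here; the view path below requires step == 1 anyway).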
std::vector<int64_t> get_slice_sizes(
    ComputeGraph& graph,
    ValueRef in_ref,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref) {
  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  std::optional<int64_t> opt_start =
      graph.extract_optional_scalar<int64_t>(opt_start_ref);
  std::optional<int64_t> opt_end =
      graph.extract_optional_scalar<int64_t>(opt_end_ref);

  int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);
  int64_t start = opt_start.value_or(0);
  int64_t end = opt_end.value_or(dim_size);

  start = normalize_idx(start, dim_size, 0);
  end = normalize_idx(end, dim_size, dim_size);

  std::vector<int64_t> new_out_sizes = graph.sizes_of(in_ref);
  new_out_sizes.at(dim) = end - start;

  return new_out_sizes;
}

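// Resize callback for slice view nodes: recomputes the sliced sizes from the
// (possibly resized) input and applies them to the view tensor.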
void resize_slice_view_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;
  vTensorPtr out = graph->get_tensor(extra_args[0]);

  std::vector<int64_t> new_out_sizes = get_slice_sizes(
      *graph,
      extra_args[1], // input
      extra_args[2], // dim
      extra_args[3], // optional start
      extra_args[4]); // optional end

  out->virtual_resize(new_out_sizes);
}

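// Validates that a zero-copy slice view is legal: the output must be a
// registered view of the input, the slice must start at offset 0 with unit
// step, and every dim ordered before the sliced dim in the dim order must
// have size 1 so the view covers a contiguous region of the input buffer.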
void check_slice_view_args(
    ComputeGraph& graph,
    ValueRef in_ref,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref,
    ValueRef opt_step_ref,
    ValueRef out_ref) {
  VK_CHECK_COND(
      graph.val_is_view_of(out_ref, in_ref),
      "output must be a view of the input");

  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  const int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);

  int64_t start =
      graph.extract_optional_scalar<int64_t>(opt_start_ref).value_or(0);
  int64_t end = graph.extract_optional_scalar<int64_t>(opt_end_ref).value_or(0);
  int64_t step =
      graph.extract_optional_scalar<int64_t>(opt_step_ref).value_or(1);

  start = normalize_idx(start, dim_size, 0);
  end = normalize_idx(end, dim_size, dim_size);

  // The start idx must be 0; this is to ensure that the start of the slice view
  // does not have any offset with respect to the base buffer storage. If the
  // offset is nonzero, then it will potentially change upon a resize; however
  // the buffer offset of the view tensor will have been "locked in" when the
  // descriptor for its buffer storage is bound to a compute shader. Therefore
  // there is no way to update the offset of the view once it has been bound.
  VK_CHECK_COND(start == 0, "start must be 0 for slice view");
  VK_CHECK_COND(step == 1, "step must be 1 for slice view");

  VK_CHECK_COND(
      end < dim_size, "end must be less than dim size for slice view");

  // We must also check that all earlier dims in the dim order have a size of 1.
  // This ensures that the slice view encompasses a contiguous memory region of
  // the source tensor's memory buffer.
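  // Illustrative example: with sizes [1, 1, 6, 8] and the default dim order,
  // slicing dim = 2 is allowed because the dims ordered before it all have
  // size 1, so the first `end` rows form one contiguous span of the buffer
  // (start is already constrained to 0 above).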
  std::vector<int64_t> in_sizes = graph.sizes_of(in_ref);
  std::vector<int64_t> in_dim_order = graph.dim_order_of(in_ref);
  for (int i = 0; i < in_dim_order.size(); ++i) {
    if (in_dim_order[i] == dim) {
      break;
    }
    VK_CHECK_COND(in_sizes[in_dim_order[i]] == 1);
  }
}

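// Adds a zero-copy slice: after validating the arguments, the output view is
// simply resized to the slice's sizes, and a resize-only ExecuteNode is added
// so the view tracks future resizes of the input. No shader is dispatched.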
void add_slice_view_node(
    ComputeGraph& graph,
    ValueRef in_ref,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref,
    ValueRef opt_step_ref,
    ValueRef out_ref) {
  check_slice_view_args(
      graph,
      in_ref,
      dim_ref,
      opt_start_ref,
      opt_end_ref,
      opt_step_ref,
      out_ref);

  std::vector<int64_t> new_out_sizes =
      get_slice_sizes(graph, in_ref, dim_ref, opt_start_ref, opt_end_ref);

  graph.get_tensor(out_ref)->virtual_resize(new_out_sizes);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      resize_slice_view_node,
      {out_ref, in_ref, dim_ref, opt_start_ref, opt_end_ref, opt_step_ref}));
}

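// Entry point for aten.slice_copy.Tensor: always materializes the slice into
// a separate output tensor.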
void slice_tensor_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_slice_tensor_copy_node(
      graph,
      args[0],
      args[1], // dim
      args[2], // optional start
      args[3], // optional end
      args[4], // step
      args[5]);
}

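// Entry point for aten.slice.Tensor: use the zero-copy view path when the
// output is registered as a view of the input, otherwise fall back to copying.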
void slice_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  ValueRef in = args[0];
  ValueRef out = args[5];

  // Special case if out is a view of in
  if (graph.val_is_view_of(out, in)) {
    add_slice_view_node(
        graph,
        in,
        args[1], // dim
        args[2], // optional start
        args[3], // optional end
        args[4], // step
        out);
    return;
  }

  add_slice_tensor_copy_node(
      graph,
      in,
      args[1], // dim
      args[2], // optional start
      args[3], // optional end
      args[4], // step
      out);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_copy);
  VK_REGISTER_OP(aten.slice.Tensor, slice_tensor);
}

} // namespace vkcompute