/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Slice.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

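// Illustrative example (not in the original source), assuming normalize()
// performs the usual wrap-around index normalization: normalize_idx(-1, 8, 0)
// evaluates to 7, while normalize_idx(INT64_MAX, 8, 8) returns the default
// value 8, since INT64_MAX marks an unspecified slice bound.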
inline int64_t normalize_idx(
    const int64_t index,
    const int64_t max,
    const int64_t default_value) {
  // INT64_MAX is passed when value is unspecified
  if (index == INT64_MAX) {
    return default_value;
  }
  if (index == default_value) {
    return index;
  }
  return normalize(index, max);
}

void add_slice_tensor_copy_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref,
    ValueRef step_ref,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim));

  // Need to normalize the dim
  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);

  VK_CHECK_COND(
      -t_in->dim() <= dim && dim < t_in->dim(),
      "dim must be in range of [-self.dim(), self.dim()), but current dim's value is ",
      dim,
      " and self.dim() = ",
      t_in->dim());

  dim = normalize(dim, t_in->dim());

  DimIndex dim_index = normalize_to_dim_index(*t_in, dim);

  std::optional<int64_t> opt_start =
      graph.extract_optional_scalar<int64_t>(opt_start_ref);
  std::optional<int64_t> opt_end =
      graph.extract_optional_scalar<int64_t>(opt_end_ref);
  int64_t step = graph.extract_scalar<int64_t>(step_ref);

  const auto in_sizes = t_in->sizes();
  const auto out_sizes = t_out->sizes();

  int64_t start = opt_start.value_or(0);
  int64_t end = opt_end.value_or(in_sizes[dim]);

  start = normalize_idx(start, in_sizes[dim], 0);
  end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);
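  // Illustrative example (not in the original source): for in_sizes[dim] == 10,
  // an unspecified end resolves to 10 and a start of -4 normalizes to 6, so
  // the slice covers indices [6, 10).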

  if (dim_index == kChannel4D) {
    // slice by channel
    std::string kernel_name = "slice_channel";
    kernel_name.reserve(kShaderNameReserve);
    add_dtype_suffix(kernel_name, *t_out);

    const struct Block final {
      int offset;
      int step;
    } params{
        static_cast<int32_t>(start),
        static_cast<int32_t>(step),
    };

    graph.execute_nodes().emplace_back(new DispatchNode(
        graph,
        VK_KERNEL_FROM_STR(kernel_name),
        graph.create_global_wg_size(out),
        graph.create_local_wg_size(out),
        {{out, vkapi::MemoryAccessType::WRITE},
         {in, vkapi::MemoryAccessType::READ}},
        {t_out->sizes_ubo(),
         t_in->sizes_ubo(),
         graph.create_params_buffer(params)}));

  } else {
    // GPU coordinates are in x, y, z
    int64_t gpu_dim = -1;
    int64_t stride = 1;
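    // Note (added for illustration): the checks below verify that the sliced
    // dim of the output has size 1 + (end - start - 1) / step; e.g. start = 1,
    // end = 8, step = 3 selects indices {1, 4, 7}, giving a size of 3.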
    if (dim_index == kWidth4D) {
      gpu_dim = 0; // width: x dimension in gpu
      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
    } else if (dim_index == kHeight4D) {
      gpu_dim = 1; // height: y dimension
      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
    } else if (dim_index == kBatch4D) {
      gpu_dim = 2; // batch: z dimension

      // Due to channel packing, each batch value spans `stride` texel planes
      int64_t n_channels = dim_at(in_sizes, kChannel4D);
      stride = utils::div_up_4(n_channels);
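      // Illustrative example (not in the original source): with
      // n_channels == 10, div_up_4(10) == 3, so each batch occupies
      // 3 planes along z.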
    } else {
      VK_THROW("Unexpected ncwh_dim!");
    }

    std::string kernel_name = "slice_batch_height_width";
    kernel_name.reserve(kShaderNameReserve);
    add_dtype_suffix(kernel_name, *t_out);

    utils::uvec3 global_size = t_out->logical_limits();
    utils::uvec3 local_size = adaptive_work_group_size(global_size);

    const struct Block final {
      int dim;
      int offset;
      int step;
      int stride;
    } params{
        static_cast<int32_t>(gpu_dim),
        static_cast<int32_t>(start),
        static_cast<int32_t>(step),
        static_cast<int32_t>(stride),
    };

    graph.execute_nodes().emplace_back(new DispatchNode(
        graph,
        VK_KERNEL_FROM_STR(kernel_name),
        global_size,
        local_size,
        {{out, vkapi::MemoryAccessType::WRITE},
         {in, vkapi::MemoryAccessType::READ}},
        {t_out->sizes_ubo(), graph.create_params_buffer(params)}));
  }
}

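// Illustrative example (not in the original source): slicing dim 1 of a
// [2, 10, 4] input with start = 2 and end = 7 yields output sizes [2, 5, 4],
// since the sliced dim becomes end - start. Note that step is not considered
// here; this helper is used for slice views, where step must be 1.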
std::vector<int64_t> get_slice_sizes(
    ComputeGraph& graph,
    ValueRef in_ref,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref) {
  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  std::optional<int64_t> opt_start =
      graph.extract_optional_scalar<int64_t>(opt_start_ref);
  std::optional<int64_t> opt_end =
      graph.extract_optional_scalar<int64_t>(opt_end_ref);

  int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);
  int64_t start = opt_start.value_or(0);
  int64_t end = opt_end.value_or(dim_size);

  start = normalize_idx(start, dim_size, 0);
  end = normalize_idx(end, dim_size, dim_size);

  std::vector<int64_t> new_out_sizes = graph.sizes_of(in_ref);
  new_out_sizes.at(dim) = end - start;

  return new_out_sizes;
}

void resize_slice_view_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;
  vTensorPtr out = graph->get_tensor(extra_args[0]);

  std::vector<int64_t> new_out_sizes = get_slice_sizes(
      *graph,
      extra_args[1], // input
      extra_args[2], // dim
      extra_args[3], // optional start
      extra_args[4]); // optional end

  out->virtual_resize(new_out_sizes);
}

void check_slice_view_args(
    ComputeGraph& graph,
    ValueRef in_ref,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref,
    ValueRef opt_step_ref,
    ValueRef out_ref) {
  VK_CHECK_COND(
      graph.val_is_view_of(out_ref, in_ref),
      "output must be a view of the input");

  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  const int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);

  int64_t start =
      graph.extract_optional_scalar<int64_t>(opt_start_ref).value_or(0);
  int64_t end = graph.extract_optional_scalar<int64_t>(opt_end_ref).value_or(0);
  int64_t step =
      graph.extract_optional_scalar<int64_t>(opt_step_ref).value_or(1);

  start = normalize_idx(start, dim_size, 0);
  end = normalize_idx(end, dim_size, dim_size);

  // The start idx must be 0; this is to ensure that the start of the slice view
  // does not have any offset with respect to the base buffer storage. If the
  // offset is nonzero, then it will potentially change upon a resize; however,
  // the buffer offset of the view tensor will have been "locked in" when the
  // descriptor for its buffer storage is bound to a compute shader. Therefore
  // there is no way to update the offset of the view once it has been bound.
  VK_CHECK_COND(start == 0, "start must be 0 for slice view");
  VK_CHECK_COND(step == 1, "step must be 1 for slice view");

  VK_CHECK_COND(
      end < dim_size, "end must be less than dim size for slice view");

  // We must also check that all earlier dims in the dim order have a size of 1.
  // This ensures that the slice view encompasses a contiguous memory region of
  // the source tensor's memory buffer.
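  // Illustrative example (not in the original source): for an input with sizes
  // [1, 1, 8, 6] and the default dim order {0, 1, 2, 3}, slicing dim 2 passes
  // this check because dims 0 and 1, which come earlier in the dim order, both
  // have size 1, so the view covers a contiguous region of the buffer.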
  std::vector<int64_t> in_sizes = graph.sizes_of(in_ref);
  std::vector<int64_t> in_dim_order = graph.dim_order_of(in_ref);
  for (int i = 0; i < in_dim_order.size(); ++i) {
    if (in_dim_order[i] == dim) {
      break;
    }
    VK_CHECK_COND(in_sizes[in_dim_order[i]] == 1);
  }
}

void add_slice_view_node(
    ComputeGraph& graph,
    ValueRef in_ref,
    ValueRef dim_ref,
    ValueRef opt_start_ref,
    ValueRef opt_end_ref,
    ValueRef opt_step_ref,
    ValueRef out_ref) {
  check_slice_view_args(
      graph,
      in_ref,
      dim_ref,
      opt_start_ref,
      opt_end_ref,
      opt_step_ref,
      out_ref);

  std::vector<int64_t> new_out_sizes =
      get_slice_sizes(graph, in_ref, dim_ref, opt_start_ref, opt_end_ref);

  graph.get_tensor(out_ref)->virtual_resize(new_out_sizes);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      resize_slice_view_node,
      {out_ref, in_ref, dim_ref, opt_start_ref, opt_end_ref, opt_step_ref}));
}

void slice_tensor_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_slice_tensor_copy_node(
      graph,
      args[0],
      args[1], // dim
      args[2], // optional start
      args[3], // optional end
      args[4], // step
      args[5]);
}

void slice_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  ValueRef in = args[0];
  ValueRef out = args[5];

  // Special case if out is a view of in
  if (graph.val_is_view_of(out, in)) {
    add_slice_view_node(
        graph,
        in,
        args[1], // dim
        args[2], // optional start
        args[3], // optional end
        args[4], // step
        out);
    return;
  }

  add_slice_tensor_copy_node(
      graph,
      in,
      args[1], // dim
      args[2], // optional start
      args[3], // optional end
      args[4], // step
      out);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_copy);
  VK_REGISTER_OP(aten.slice.Tensor, slice_tensor);
}

} // namespace vkcompute