/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>

namespace vkcompute {

namespace {
void check_args(
    const api::vTensor& in,
    const std::vector<int64_t>& repeats,
    const api::vTensor& out) {
  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));

  VK_CHECK_COND(in.storage_type() == out.storage_type());
  if (in.storage_type() == utils::kTexture2D) {
    VK_CHECK_COND(in.dim() <= 2);
  }

  int64_t in_dim = in.dim();
  VK_CHECK_COND(
      in_dim <= repeats.size(),
      "Input tensor dim must not exceed the size of the repeats argument");

  VK_CHECK_COND(
      dim_at<kWidth4D>(in.sizes()) * dim_at<kWidth4D>(repeats) ==
          dim_at<kWidth4D>(out.sizes()),
      "Output's width doesn't match input's width * repeat count");

  VK_CHECK_COND(
      dim_at<kHeight4D>(in.sizes()) * dim_at<kHeight4D>(repeats) ==
          dim_at<kHeight4D>(out.sizes()),
      "Output's height doesn't match input's height * repeat count");

  VK_CHECK_COND(
      dim_at<kChannel4D>(in.sizes()) * dim_at<kChannel4D>(repeats) ==
          dim_at<kChannel4D>(out.sizes()),
      "Output's channel doesn't match input's channel * repeat count");

  VK_CHECK_COND(
      dim_at<kBatch4D>(in.sizes()) * dim_at<kBatch4D>(repeats) ==
          dim_at<kBatch4D>(out.sizes()),
      "Output's batch doesn't match input's batch * repeat count");
}

} // namespace

void add_repeat_channel_node(
    ComputeGraph& graph,
    ValueRef in,
    int64_t repeat_channel,
    ValueRef out,
    utils::ivec3& running_range) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  std::string kernel_name = "repeat_channel";
  kernel_name.reserve(kShaderNameReserve);
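  // The suffix appended below selects the shader variant matching the output
  // dtype (for example, a half-precision variant for fp16 outputs); the exact
  // name token is determined by add_dtype_suffix and the shader registry.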
  add_dtype_suffix(kernel_name, *t_out);

  const std::vector<int64_t>& in_sizes = t_in->sizes();

  int32_t in_width = utils::safe_downcast<int32_t>(dim_at<kWidth4D>(in_sizes));
  int32_t in_height =
      utils::safe_downcast<int32_t>(dim_at<kHeight4D>(in_sizes));
  int32_t in_channel =
      utils::safe_downcast<int32_t>(dim_at<kChannel4D>(in_sizes));
  int32_t in_batch = utils::safe_downcast<int32_t>(dim_at<kBatch4D>(in_sizes));

  int32_t out_channel = repeat_channel * in_channel;

  utils::ivec4 out_whcn_sizes{in_width, in_height, out_channel, in_batch};

  utils::ivec4 in_whcn_sizes{in_width, in_height, in_channel, in_batch};

  // Channel packed global work ids
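  // With channel packing, each texel along the z axis holds 4 consecutive
  // channels, so the number of texture layers is batch * ceil(channels / 4).
  // Illustrative example: batch = 2 and out_channel = 6 give 2 * 2 = 4 layers.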
  running_range[2] = out_whcn_sizes[3] * utils::div_up_4(out_whcn_sizes[2]);
  utils::uvec3 global_size = utils::make_uvec3(running_range);
  utils::uvec3 local_size = adaptive_work_group_size(global_size);

  const struct Block final {
    utils::ivec4 out_sizes;
    utils::ivec4 in_sizes;
  } repeat_channel_args{
      out_whcn_sizes,
      in_whcn_sizes,
  };

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {in, vkapi::MemoryAccessType::READ}},
      // Parameter buffers
      {graph.create_params_buffer(repeat_channel_args)},
      // Specialization Constants
      {SV(t_out->packed_dim())}));
}

void add_repeat_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef repeats_ref,
    ValueRef out) {
  std::vector<int64_t> repeats = *(graph.get_int_list(repeats_ref));

  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);
  check_args(*t_in, repeats, *t_out);

  // In this function, we expand the dimensions in the following order:
  // 1. Channel
  // 2. Width
  // 3. Height
  // 4. Batch
  // After expanding a dimension, we update the "running_range", since
  // subsequent passes need to copy the entire "expanded" area.
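  // Illustrative walk-through (sizes are hypothetical): for an input with
  // (W, H, C, N) = (3, 2, 4, 1) and repeats of 2 along width, 3 along height,
  // 2 along channel, and 1 along batch, the channel pass fills a region of
  // 3 x 2 texels over 1 * ceil(8 / 4) = 2 layers, so running_range becomes
  // {3, 2, 2}. The width pass copies that region once to reach {6, 2, 2},
  // the height pass copies it twice more to reach {6, 6, 2}, and the batch
  // repeat of 1 leaves it unchanged.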

  utils::ivec3 running_range = t_in->logical_limits();

  const std::vector<int64_t>& in_sizes = t_in->sizes();

  // Since we use channel packing, repeating the channel dimension is the most
  // complicated and time-consuming, as we need to reason over misaligned
  // channels. Hence we expand it first to minimize cost. Also, in this first
  // dimension, we copy the input texture over to the output. In subsequent
  // dimensions, we read and write from the same tensor.
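  // For instance (illustrative numbers): repeating C = 3 twice gives 6 output
  // channels, so the first output texel holds output channels 0-3, which map
  // back to input channels 0, 1, 2, 0. Input and output texel boundaries no
  // longer line up, which is why the repeat_channel shader gathers elements
  // individually instead of copying whole texels.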

  if (int64_t channel_repeat = dim_at<kChannel4D>(repeats);
      channel_repeat == 1) {
    // If no repeat, short-cut to a direct copy
    utils::ivec3 src_offset{0, 0, 0};
    utils::ivec3 dst_offset{0, 0, 0};

    add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);

  } else {
    add_repeat_channel_node(graph, in, channel_repeat, out, running_range);
  }

  // TODO: refactor width, height, and batch into a common helper function.
  // Width
  if (int64_t width_repeat = dim_at<kWidth4D>(repeats); width_repeat > 1) {
    utils::ivec3 src_offset{0, 0, 0};

    for (int i = 1; i < width_repeat; ++i) {
      utils::ivec3 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0};

      add_copy_offset_node(
          graph, out, running_range, src_offset, dst_offset, out);
    }

    running_range[0] = running_range[0] * width_repeat;
  }

  // Height
  if (int64_t height_repeat = dim_at<kHeight4D>(repeats); height_repeat > 1) {
    utils::ivec3 src_offset{0, 0, 0};

    for (int i = 1; i < height_repeat; ++i) {
      utils::ivec3 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0};

      add_copy_offset_node(
          graph, out, running_range, src_offset, dst_offset, out);
    }

    running_range[1] = running_range[1] * height_repeat;
  }

  // Batch
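  // Batch entries are stacked along the texture z axis in the channel-packed
  // layout, so the i-th batch copy is offset by i * running_range[2] layers
  // (the depth of the region populated so far) rather than by an element
  // count.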
  if (int64_t batch_repeat = dim_at<kBatch4D>(repeats); batch_repeat > 1) {
    utils::ivec3 src_offset{0, 0, 0};

    for (int i = 1; i < batch_repeat; ++i) {
      utils::ivec3 dst_offset = {0, 0, i * running_range[2]};

      add_copy_offset_node(
          graph, out, running_range, src_offset, dst_offset, out);
    }

    running_range[2] = running_range[2] * batch_repeat;
  }
}

void repeat(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  add_repeat_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.repeat.default, repeat);
}

} // namespace vkcompute