#include <ATen/TensorOperators.h>
#include <ATen/native/vulkan/ops/Gru.h>
#include <ATen/native/vulkan/ops/Mm.h>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/addmm.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/gru.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/tanh.h>
#endif

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {
//
// input_vk: input tensor containing the features of the input sequence
//           tensor of shape (N, L, H_in) when batch_first=True
//                           (L, N, H_in) when batch_first=False
//
// hx_vk: initial hidden state for each element in the batch.
//        tensor of shape (D * num_layers, N, H_out)
//
// output: tensor of shape (N, L, D * H_out) when batch_first=True
//                         (L, N, D * H_out) when batch_first=False
//
// h_n: tensor of shape (D * num_layers, N, H_out)
//
// where
//   L = sequence length
//   N = batch size
//   D = 2 if bidirectional=True otherwise 1
//   H_in = input_size (# of expected features in the input x)
//   H_out = hidden_size (# of features in the hidden state h)
//
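// For reference, the per-layer loop below implements the standard GRU cell
// equations (as documented for torch.nn.GRU):
//   r  = sigmoid(x @ W_ir^T + b_ir + h @ W_hr^T + b_hr)    (reset gate)
//   z  = sigmoid(x @ W_iz^T + b_iz + h @ W_hz^T + b_hz)    (update gate)
//   n  = tanh(x @ W_in^T + b_in + r * (h @ W_hn^T + b_hn)) (new gate)
//   h' = (1 - z) * n + z * h
//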
std::tuple<Tensor, Tensor> gru_input(
    const Tensor& input_vk, // input sequence (vulkan)
    const Tensor& hx_vk, // initial hidden state (vulkan)
    TensorList params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'.");
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3,
      "Vulkan gru expects 'input_vk' dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan gru expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan gru expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan gru expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan gru expects 'dropout' to be 0.0.");

  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan gru expects batch-first input");

  const auto hidden_size = hx_vk.size(2);
  std::vector<at::Tensor> h_n_list; // hidden output

  // reshape to 2D because the Vulkan at::mm op only accepts 2D inputs
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  for (int64_t i = 0; i < num_layers; ++i) {
    // extract this layer's hidden state and reshape it to 2D
    auto h = at::slice(hx_vk, 0, i, i + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    const auto& w_ih = params_cpu[i * 4];
    const auto& w_hh = params_cpu[i * 4 + 1];
    const auto& b_ih = params_cpu[i * 4 + 2];
    const auto& b_hh = params_cpu[i * 4 + 3];

    // split the stacked weights/biases into the r, z, n chunks
    const auto& w_i_rzn = w_ih.split(hidden_size);
    const auto& w_h_rzn = w_hh.split(hidden_size);
    const auto& b_i_rzn = b_ih.split(hidden_size);
    const auto& b_h_rzn = b_hh.split(hidden_size);

    const auto& w_ir = w_i_rzn[0];
    const auto& w_iz = w_i_rzn[1];
    const auto& w_in = w_i_rzn[2];
    const auto& w_hr = w_h_rzn[0];
    const auto& w_hz = w_h_rzn[1];
    const auto& w_hn = w_h_rzn[2];
    const auto& b_ir = b_i_rzn[0];
    const auto& b_iz = b_i_rzn[1];
    const auto& b_in = b_i_rzn[2];
    const auto& b_hr = b_h_rzn[0];
    const auto& b_hz = b_h_rzn[1];
    const auto& b_hn = b_h_rzn[2];

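    // gate computations: r (reset), z (update), n (new/candidate state);
    // see the GRU cell equations in the comment above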
    const auto& r = at::sigmoid(
        at::addmm(b_ir, x, w_ir.t()) + at::addmm(b_hr, h, w_hr.t()));
    const auto& z = at::sigmoid(
        at::addmm(b_iz, x, w_iz.t()) + at::addmm(b_hz, h, w_hz.t()));
    const auto& n = at::tanh(
        at::addmm(b_in, x, w_in.t()) + r * (at::addmm(b_hn, h, w_hn.t())));
    h = (z * (-1) + 1) * n + z * h; // h' = (1 - z) * n + z * h
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  return std::tuple<Tensor, Tensor>(x, h_n);
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::gru.input"), TORCH_FN(gru_input));
}

#endif /* USE_VULKAN_API */

} // namespace

static std::vector<c10::intrusive_ptr<LinearPackedContext>>
pack_linear_op_contexts(
    const std::vector<Tensor>& params_cpu,
    int64_t num_layers) {
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'."
      " But 'params_cpu' has size: ",
      params_cpu.size(),
      " and 'num_layers' is: ",
      num_layers);
  std::vector<c10::intrusive_ptr<LinearPackedContext>> linear_op_contexts;
  linear_op_contexts.reserve(num_layers * 6);
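  // Each layer contributes six linear contexts (one per weight/bias pair),
  // in the order: ir, hr, iz, hz, in, hn. run_gru_context below indexes
  // them in the same order.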

  for (int64_t i = 0; i < num_layers; ++i) {
    const auto& w_ih = params_cpu.at(i * 4);
    const auto& w_hh = params_cpu.at(i * 4 + 1);
    const auto& b_ih = params_cpu.at(i * 4 + 2);
    const auto& b_hh = params_cpu.at(i * 4 + 3);
    const auto& hidden_size = w_ih.size(0) / 3;

    const auto& w_i_rzn = w_ih.split(hidden_size);
    const auto& w_h_rzn = w_hh.split(hidden_size);
    const auto& b_i_rzn = b_ih.split(hidden_size);
    const auto& b_h_rzn = b_hh.split(hidden_size);

    const auto& w_ir = w_i_rzn[0];
    const auto& w_iz = w_i_rzn[1];
    const auto& w_in = w_i_rzn[2];
    const auto& w_hr = w_h_rzn[0];
    const auto& w_hz = w_h_rzn[1];
    const auto& w_hn = w_h_rzn[2];
    const auto& b_ir = b_i_rzn[0];
    const auto& b_iz = b_i_rzn[1];
    const auto& b_in = b_i_rzn[2];
    const auto& b_hr = b_h_rzn[0];
    const auto& b_hz = b_h_rzn[1];
    const auto& b_hn = b_h_rzn[2];

    linear_op_contexts.emplace_back(create_linear_context(w_ir.t(), b_ir));
    linear_op_contexts.emplace_back(create_linear_context(w_hr.t(), b_hr));
    linear_op_contexts.emplace_back(create_linear_context(w_iz.t(), b_iz));
    linear_op_contexts.emplace_back(create_linear_context(w_hz.t(), b_hz));
    linear_op_contexts.emplace_back(create_linear_context(w_in.t(), b_in));
    linear_op_contexts.emplace_back(create_linear_context(w_hn.t(), b_hn));
  }
  return linear_op_contexts;
}

GruPackedContext::GruPackedContext(
    const std::vector<Tensor>& params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan gru expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan gru expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan gru expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan gru expects 'dropout' to be 0.0.");

  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(pack_linear_op_contexts(params_cpu, num_layers));
  packed_.emplace_back(has_biases);
  packed_.emplace_back(num_layers);
  packed_.emplace_back(dropout);
  packed_.emplace_back(train);
  packed_.emplace_back(bidirectional);
  packed_.emplace_back(batch_first);
}

GruPackedContext GruPackedContext::pack(c10::impl::GenericList unpacked) {
  return GruPackedContext(
      unpacked.get(Unpacked::Params).toTensorVector(),
      unpacked.get(Unpacked::hasBiases).toBool(),
      unpacked.get(Unpacked::NumLayers).toInt(),
      unpacked.get(Unpacked::Dropout).toDouble(),
      unpacked.get(Unpacked::Train).toBool(),
      unpacked.get(Unpacked::Bidirectional).toBool(),
      unpacked.get(Unpacked::BatchFirst).toBool());
}

const c10::impl::GenericList GruPackedContext::unpack() const {
  c10::impl::GenericList unpacked_gru_context{c10::AnyType::get()};
  unpacked_gru_context.reserve(Unpacked::NumArgs);

  const c10::List<c10::IValue> packed_linear_contexts =
      get_val(Packed::LinearContexts).toList();

  const int64_t num_layers = get_val(Packed::NumLayers).toInt();
  const int64_t linear_contexts_per_layer = 6;

  std::vector<Tensor> params_cpu;
  params_cpu.reserve(num_layers * linear_contexts_per_layer);

  for (c10::IValue packed_linear_context : packed_linear_contexts) {
    const c10::impl::GenericList unpacked_linear_context =
        packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();

    TORCH_CHECK(
        unpacked_linear_context.size() > 0u,
        "unpacked_linear_context does not have any elements!");

    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Weight)
            .toTensor()
            .t());
    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Bias)
            .toTensor());
  }
  unpacked_gru_context.emplace_back(params_cpu);
  for (int64_t i = 1; i < Unpacked::NumArgs; ++i) {
    unpacked_gru_context.emplace_back(get_val(i));
  }

  return unpacked_gru_context;
}

c10::intrusive_ptr<GruPackedContext> create_gru_context(
    std::vector<Tensor>&& params_cpu,
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  return c10::make_intrusive<GruPackedContext>(GruPackedContext(
      params_cpu,
      has_biases,
      num_layers,
      dropout,
      train,
      bidirectional,
      batch_first));
}

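// Typical usage (sketch, not taken from this file): prepack the CPU
// weights/biases once, then run inference with Vulkan tensors, e.g.
//   const auto gru_ctx = create_gru_context(
//       std::move(params_cpu), /*has_biases=*/true, num_layers,
//       /*dropout=*/0.0, /*train=*/false, /*bidirectional=*/false,
//       /*batch_first=*/true);
//   Tensor output, h_n;
//   std::tie(output, h_n) = run_gru_context(input_vk, hx_vk, gru_ctx);
// where 'params_cpu', 'input_vk', 'hx_vk', and 'gru_ctx' are illustrative
// names.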
std::tuple<Tensor, Tensor> run_gru_context(
    const Tensor& input_vk, // input sequence (vulkan)
    const Tensor& hx_vk, // initial hidden state (vulkan)
    const c10::intrusive_ptr<GruPackedContext>& gru_context) {
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3,
      "Vulkan gru expects 'input_vk' dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3.");

  const int64_t num_layers =
      gru_context->get_val(GruPackedContext::Packed::NumLayers).toInt();
  const bool batch_first =
      gru_context->get_val(GruPackedContext::Packed::BatchFirst).toBool();
  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan gru expects batch-first input");

  const c10::List<c10::IValue> packed_linear_contexts =
      gru_context->get_val(GruPackedContext::Packed::LinearContexts).toList();

  const int64_t linear_contexts_per_layer = 6;
  // (b_ir, w_ir), (b_hr, w_hr), (b_iz, w_iz),
  // (b_hz, w_hz), (b_in, w_in), (b_hn, w_hn)
  std::vector<at::Tensor> h_n_list; // hidden output

  // reshape to 2D because the Vulkan at::mm op only accepts 2D inputs
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  for (int64_t i = 0; i < num_layers; ++i) {
    // extract this layer's hidden state and reshape it to 2D
    auto h = at::slice(hx_vk, 0, i, i + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    const auto& cxt_ir =
        packed_linear_contexts[i * linear_contexts_per_layer + 0]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hr =
        packed_linear_contexts[i * linear_contexts_per_layer + 1]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_iz =
        packed_linear_contexts[i * linear_contexts_per_layer + 2]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hz =
        packed_linear_contexts[i * linear_contexts_per_layer + 3]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_in =
        packed_linear_contexts[i * linear_contexts_per_layer + 4]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hn =
        packed_linear_contexts[i * linear_contexts_per_layer + 5]
            .toCustomClass<LinearPackedContext>();

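    // same gate computations as in gru_input above, but each matmul runs
    // through a prepacked linear context instead of at::addmm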
    const auto& r = at::sigmoid(
        run_linear_context(x, cxt_ir) + run_linear_context(h, cxt_hr));
    // cxt_ir->run(x, 1.0f, 1.0f) + cxt_hr->run(h, 1.0f, 1.0f));
    const auto& z = at::sigmoid(
        run_linear_context(x, cxt_iz) + run_linear_context(h, cxt_hz));
    // cxt_iz->run(x, 1.0f, 1.0f) + cxt_hz->run(h, 1.0f, 1.0f));
    const auto& n = at::tanh(
        run_linear_context(x, cxt_in) + r * run_linear_context(h, cxt_hn));
    // cxt_in->run(x, 1.0f, 1.0f) + r * (cxt_hn->run(h, 1.0f, 1.0f)));
    h = (z * (-1) + 1) * n + z * h; // h' = (1 - z) * n + z * h
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  return std::tuple<Tensor, Tensor>(x, h_n);
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at