#include <ATen/TensorOperators.h>
#include <ATen/native/vulkan/ops/Lstm.h>
#include <ATen/native/vulkan/ops/Mm.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/addmm.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/tanh.h>
#endif

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {
//
// input_vk: input tensor of shape (L, N, H_in) when batch_first=False or
// (N, L, H_in) when batch_first=True containing the features of the input
// sequence
//
// hx_vk: tensor of shape (D * num_layers, N, H_out) containing the initial
// hidden state for each element in the input sequence.
//
// cx_vk: tensor of shape (D * num_layers, N, H_cell) containing the initial
// cell state for each element in the input sequence.
//
// output: tensor of shape (L, N, D * H_out) when batch_first=False or
// (N, L, D * H_out) when batch_first=True, containing the output features
// (h_t) from the last layer of the LSTM, for each t
//
// h_n: tensor of shape (D * num_layers, N, H_out) containing the final hidden
// state for each element in the sequence.
//
// c_n: tensor of shape (D * num_layers, N, H_cell) containing the final cell
// state for each element in the sequence.
//
// where
//   L = sequence length
//   N = batch size
//   D = 2 if bidirectional=True otherwise 1
//   H_in = input_size (# of expected features in the input x)
//   H_cell = hidden_size (# of features in the hidden state h)
//   H_out = hidden_size
//
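// For illustration only (the concrete values below are assumed, not taken
// from this file): with batch_first=True, num_layers=2, bidirectional=False
// (so D=1), N=1, L=1, H_in=16 and hidden_size=32, the shapes would be
//   input_vk: (1, 1, 16)   hx_vk: (2, 1, 32)   cx_vk: (2, 1, 32)
//   output:   (1, 1, 32)   h_n:   (2, 1, 32)   c_n:   (2, 1, 32)
//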
std::tuple<Tensor, Tensor, Tensor> lstm_input(
    const Tensor& input_vk, // input sequence (vulkan)
    TensorList
        hx, // initial hidden state (vulkan) & initial cell state (vulkan)
    TensorList params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_CHECK(
      hx[0].size(2) == hx[1].size(2),
      "Vulkan LSTM with projections is not supported");
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan LSTM expects 'params_cpu' size to be 4 * 'num_layers'.");
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3, "Vulkan LSTM expects input dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx[0].sizes().size() == 3,
      "Vulkan LSTM expects hidden state dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx[1].sizes().size() == 3,
      "Vulkan LSTM expects cell state dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan LSTM expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan LSTM expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan LSTM expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan LSTM expects 'dropout' to be 0.0.");

  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan LSTM expects batch-first input");

  const Tensor& hx_vk = hx[0];
  const Tensor& cx_vk = hx[1];

  const auto hidden_size = hx_vk.size(2);
  std::vector<at::Tensor> h_n_list; // hidden state output
  std::vector<at::Tensor> c_n_list; // cell state output

  // reshape to 2D because the Vulkan at::mm op only accepts 2D tensors
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  h_n_list.reserve(num_layers);
  c_n_list.reserve(num_layers);

  for (int64_t l = 0; l < num_layers; ++l) {
    // extract this layer's hidden and cell state and squeeze them to 2D
    auto h = at::slice(hx_vk, 0, l, l + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    auto c = at::slice(cx_vk, 0, l, l + 1, 1);
    c = c.reshape({c.size(0) * c.size(1), c.size(2)});

    const auto& w_ih = params_cpu[l * 4];
    const auto& w_hh = params_cpu[l * 4 + 1];
    const auto& b_ih = params_cpu[l * 4 + 2];
    const auto& b_hh = params_cpu[l * 4 + 3];

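    // w_ih, w_hh, b_ih and b_hh stack the four gates along dim 0 in the order
    // input (i), forget (f), cell (g), output (o), so splitting into chunks
    // of hidden_size rows recovers the per-gate weights and biases.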
    const auto& w_i_ifgo = w_ih.split(hidden_size);
    const auto& w_h_ifgo = w_hh.split(hidden_size);
    const auto& b_i_ifgo = b_ih.split(hidden_size);
    const auto& b_h_ifgo = b_hh.split(hidden_size);

    const auto& w_ii = w_i_ifgo[0];
    const auto& w_if = w_i_ifgo[1];
    const auto& w_ig = w_i_ifgo[2];
    const auto& w_io = w_i_ifgo[3];
    const auto& w_hi = w_h_ifgo[0];
    const auto& w_hf = w_h_ifgo[1];
    const auto& w_hg = w_h_ifgo[2];
    const auto& w_ho = w_h_ifgo[3];
    const auto& b_ii = b_i_ifgo[0];
    const auto& b_if = b_i_ifgo[1];
    const auto& b_ig = b_i_ifgo[2];
    const auto& b_io = b_i_ifgo[3];
    const auto& b_hi = b_h_ifgo[0];
    const auto& b_hf = b_h_ifgo[1];
    const auto& b_hg = b_h_ifgo[2];
    const auto& b_ho = b_h_ifgo[3];

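    // Standard LSTM cell update, matching the computation below:
    //   i = sigmoid(x W_ii^T + b_ii + h W_hi^T + b_hi)
    //   f = sigmoid(x W_if^T + b_if + h W_hf^T + b_hf)
    //   g = tanh   (x W_ig^T + b_ig + h W_hg^T + b_hg)
    //   o = sigmoid(x W_io^T + b_io + h W_ho^T + b_ho)
    //   c' = f * c + i * g
    //   h' = o * tanh(c')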
    const auto& i = at::sigmoid(
        at::addmm(b_ii, x, w_ii.t()) + at::addmm(b_hi, h, w_hi.t()));
    const auto& f = at::sigmoid(
        at::addmm(b_if, x, w_if.t()) + at::addmm(b_hf, h, w_hf.t()));
    const auto& g =
        at::tanh(at::addmm(b_ig, x, w_ig.t()) + at::addmm(b_hg, h, w_hg.t()));
    const auto& o = at::sigmoid(
        at::addmm(b_io, x, w_io.t()) + at::addmm(b_ho, h, w_ho.t()));
    c = f * c + i * g;
    h = o * at::tanh(c);
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
    c_n_list.emplace_back(
        c.reshape({1, 1, c.size(0), c.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  auto c_n = at::cat(c_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  c_n = c_n.reshape({c_n.size(0) * c_n.size(1), c_n.size(2), c_n.size(3)});
  return std::tuple<Tensor, Tensor, Tensor>(x, h_n, c_n);
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::lstm.input"), TORCH_FN(lstm_input));
}

#endif /* USE_VULKAN_API */

} // namespace

static std::vector<c10::intrusive_ptr<LinearPackedContext>>
pack_lstm_linear_op_contexts(
    const std::vector<Tensor>& params_cpu,
    int64_t num_layers) {
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan LSTM expects 'params_cpu' size to be 4 * 'num_layers'."
      " But 'params_cpu' has size: ",
      params_cpu.size(),
      " and 'num_layers' is: ",
      num_layers);
  std::vector<c10::intrusive_ptr<LinearPackedContext>> linear_op_contexts;
  linear_op_contexts.reserve(num_layers * 8);

  for (int64_t l = 0; l < num_layers; ++l) {
    const auto& w_ih = params_cpu[l * 4];
    const auto& w_hh = params_cpu[l * 4 + 1];
    const auto& b_ih = params_cpu[l * 4 + 2];
    const auto& b_hh = params_cpu[l * 4 + 3];
    const auto& hidden_size = w_ih.size(0) / 4;

    const auto& w_i_ifgo = w_ih.split(hidden_size);
    const auto& w_h_ifgo = w_hh.split(hidden_size);
    const auto& b_i_ifgo = b_ih.split(hidden_size);
    const auto& b_h_ifgo = b_hh.split(hidden_size);

    const auto& w_ii = w_i_ifgo[0];
    const auto& w_if = w_i_ifgo[1];
    const auto& w_ig = w_i_ifgo[2];
    const auto& w_io = w_i_ifgo[3];
    const auto& w_hi = w_h_ifgo[0];
    const auto& w_hf = w_h_ifgo[1];
    const auto& w_hg = w_h_ifgo[2];
    const auto& w_ho = w_h_ifgo[3];
    const auto& b_ii = b_i_ifgo[0];
    const auto& b_if = b_i_ifgo[1];
    const auto& b_ig = b_i_ifgo[2];
    const auto& b_io = b_i_ifgo[3];
    const auto& b_hi = b_h_ifgo[0];
    const auto& b_hf = b_h_ifgo[1];
    const auto& b_hg = b_h_ifgo[2];
    const auto& b_ho = b_h_ifgo[3];

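    // Pack one linear context per gate matmul, in the exact order indexed by
    // run_lstm_context below: (ii, hi), (if, hf), (ig, hg), (io, ho). Each
    // weight is transposed so the packed context mirrors addmm(b, x, w.t())
    // from lstm_input above.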
    linear_op_contexts.emplace_back(create_linear_context(w_ii.t(), b_ii));
    linear_op_contexts.emplace_back(create_linear_context(w_hi.t(), b_hi));
    linear_op_contexts.emplace_back(create_linear_context(w_if.t(), b_if));
    linear_op_contexts.emplace_back(create_linear_context(w_hf.t(), b_hf));
    linear_op_contexts.emplace_back(create_linear_context(w_ig.t(), b_ig));
    linear_op_contexts.emplace_back(create_linear_context(w_hg.t(), b_hg));
    linear_op_contexts.emplace_back(create_linear_context(w_io.t(), b_io));
    linear_op_contexts.emplace_back(create_linear_context(w_ho.t(), b_ho));
  }
  return linear_op_contexts;
}

LstmPackedContext::LstmPackedContext(
    const std::vector<Tensor>& params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan LSTM expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan LSTM expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan LSTM expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan LSTM expects 'dropout' to be 0.0.");

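  // The emplace order below defines the LstmPackedContext::Packed indices
  // that get_val() reads back elsewhere in this file (e.g.
  // Packed::LinearContexts, Packed::NumLayers, Packed::BatchFirst).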
  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(pack_lstm_linear_op_contexts(params_cpu, num_layers));
  packed_.emplace_back(has_biases);
  packed_.emplace_back(num_layers);
  packed_.emplace_back(dropout);
  packed_.emplace_back(train);
  packed_.emplace_back(bidirectional);
  packed_.emplace_back(batch_first);
}

LstmPackedContext LstmPackedContext::pack(c10::impl::GenericList unpacked) {
  return LstmPackedContext(
      unpacked.get(Unpacked::Params).toTensorVector(),
      unpacked.get(Unpacked::hasBiases).toBool(),
      unpacked.get(Unpacked::NumLayers).toInt(),
      unpacked.get(Unpacked::Dropout).toDouble(),
      unpacked.get(Unpacked::Train).toBool(),
      unpacked.get(Unpacked::Bidirectional).toBool(),
      unpacked.get(Unpacked::BatchFirst).toBool());
}

const c10::impl::GenericList LstmPackedContext::unpack() const {
  c10::impl::GenericList unpacked_lstm_context{c10::AnyType::get()};
  unpacked_lstm_context.reserve(Unpacked::NumArgs);

  const c10::List<c10::IValue> packed_linear_contexts =
      get_val(Packed::LinearContexts).toList();

  const int64_t num_layers = get_val(Packed::NumLayers).toInt();
  const int64_t linear_contexts_per_layer = 8;

  std::vector<Tensor> params_cpu;
  params_cpu.reserve(num_layers * linear_contexts_per_layer);

  for (c10::IValue packed_linear_context : packed_linear_contexts) {
    const c10::impl::GenericList unpacked_linear_context =
        packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();

    TORCH_CHECK(
        unpacked_linear_context.size() > 0u,
        "unpacked_linear_context does not have any elements!");

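    // The linear contexts were packed from transposed weights in
    // pack_lstm_linear_op_contexts, so transpose back here to restore each
    // gate weight to its original orientation.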
    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Weight)
            .toTensor()
            .t());
    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Bias)
            .toTensor());
  }
  unpacked_lstm_context.emplace_back(params_cpu);
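  // Copy the remaining scalar arguments (has_biases, num_layers, dropout,
  // train, bidirectional, batch_first), which occupy indices 1..6 in both the
  // packed and unpacked lists.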
  for (int64_t i = 1; i < 7; ++i) {
    unpacked_lstm_context.emplace_back(get_val(i));
  }

  return unpacked_lstm_context;
}

c10::intrusive_ptr<LstmPackedContext> create_lstm_context(
    std::vector<Tensor>&& params_cpu,
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  return c10::make_intrusive<LstmPackedContext>(LstmPackedContext(
      params_cpu,
      has_biases,
      num_layers,
      dropout,
      train,
      bidirectional,
      batch_first));
}

std::tuple<Tensor, Tensor, Tensor> run_lstm_context(
    const Tensor& input_vk, // input sequence (vulkan)
    const Tensor& hx_vk, // initial hidden state (vulkan)
    const Tensor& cx_vk, // initial cell state (vulkan)
    const c10::intrusive_ptr<LstmPackedContext>& lstm_context) {
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3, "Vulkan LSTM expects input dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx_vk.sizes().size() == 3,
      "Vulkan LSTM expects hidden state dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      cx_vk.sizes().size() == 3,
      "Vulkan LSTM expects cell state dims to be 3.");

  const int64_t num_layers =
      lstm_context->get_val(LstmPackedContext::Packed::NumLayers).toInt();
  const bool batch_first =
      lstm_context->get_val(LstmPackedContext::Packed::BatchFirst).toBool();
  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan LSTM expects batch-first input");

  const c10::List<c10::IValue> packed_linear_op_contexts =
      lstm_context->get_val(LstmPackedContext::Packed::LinearContexts).toList();

  const int64_t linear_op_contexts_per_layer = 8;
  // (b_ii, w_ii), (b_hi, w_hi), (b_if, w_if), (b_hf, w_hf),
  // (b_ig, w_ig), (b_hg, w_hg), (b_io, w_io), (b_ho, w_ho)

  std::vector<at::Tensor> h_n_list; // hidden state output
  std::vector<at::Tensor> c_n_list; // cell state output

  // reshape to 2D because the Vulkan at::mm op only accepts 2D tensors
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  h_n_list.reserve(num_layers);
  c_n_list.reserve(num_layers);

  for (int64_t l = 0; l < num_layers; ++l) {
    // extract this layer's hidden and cell state and squeeze them to 2D
    auto h = at::slice(hx_vk, 0, l, l + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    auto c = at::slice(cx_vk, 0, l, l + 1, 1);
    c = c.reshape({c.size(0) * c.size(1), c.size(2)});

    const auto& cxt_ii =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 0]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hi =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 1]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_if =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 2]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hf =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 3]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_ig =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 4]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hg =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 5]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_io =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 6]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_ho =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 7]
            .toCustomClass<LinearPackedContext>();

    const auto& i = at::sigmoid(
        run_linear_context(x, cxt_ii) + run_linear_context(h, cxt_hi));
    // cxt_ii->run(x, 1.0f, 1.0f) + cxt_hi->run(h, 1.0f, 1.0f));
    const auto& f = at::sigmoid(
        run_linear_context(x, cxt_if) + run_linear_context(h, cxt_hf));
    // cxt_if->run(x, 1.0f, 1.0f) + cxt_hf->run(h, 1.0f, 1.0f));
    const auto& g =
        at::tanh(run_linear_context(x, cxt_ig) + run_linear_context(h, cxt_hg));
    // cxt_ig->run(x, 1.0f, 1.0f) + cxt_hg->run(h, 1.0f, 1.0f));
    const auto& o = at::sigmoid(
        run_linear_context(x, cxt_io) + run_linear_context(h, cxt_ho));
    // cxt_io->run(x, 1.0f, 1.0f) + cxt_ho->run(h, 1.0f, 1.0f));
    c = f * c + i * g;
    h = o * at::tanh(c);
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
    c_n_list.emplace_back(
        c.reshape({1, 1, c.size(0), c.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  auto c_n = at::cat(c_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  c_n = c_n.reshape({c_n.size(0) * c_n.size(1), c_n.size(2), c_n.size(3)});
  return std::tuple<Tensor, Tensor, Tensor>(x, h_n, c_n);
}
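
// Illustrative prepack-and-run usage (a sketch only; the tensor names,
// shapes, and hyperparameters below are assumed for the example and are not
// taken from this file):
//
//   // Prepack the CPU weights/biases once:
//   const auto lstm_context = create_lstm_context(
//       std::move(params_cpu), /*has_biases=*/true, /*num_layers=*/1,
//       /*dropout=*/0.0, /*train=*/false, /*bidirectional=*/false,
//       /*batch_first=*/true);
//   // Then run repeatedly against Vulkan tensors:
//   Tensor out, h_n, c_n;
//   std::tie(out, h_n, c_n) =
//       run_lstm_context(input_vk, hx_vk, cx_vk, lstm_context);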

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at