#include <ATen/TensorOperators.h>
#include <ATen/native/vulkan/ops/Lstm.h>
#include <ATen/native/vulkan/ops/Mm.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/addmm.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/tanh.h>
#endif

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {
//
// input_vk: input tensor of shape (L, N, H_in) when batch_first=False or
// (N, L, H_in) when batch_first=True containing the features of the input
// sequence
//
// hx_vk: tensor of shape (D * num_layers, N, H_out) containing the initial
// hidden state for each element in the input sequence.
//
// cx_vk: tensor of shape (D * num_layers, N, H_cell) containing the initial
// cell state for each element in the input sequence.
//
// output: tensor of shape (L, N, D * H_out) when batch_first=False or
// (N, L, D * H_out) when batch_first=True, containing the output features
// (h_t) from the last layer of the LSTM, for each t
//
// h_n: tensor of shape (D * num_layers, N, H_out) containing the final hidden
// state for each element in the sequence.
//
// c_n: tensor of shape (D * num_layers, N, H_cell) containing the final cell
// state for each element in the sequence.
//
//  where
//    L = sequence length
//    N = batch size
//    D = 2 if bidirectional=True otherwise 1
//    H_in = input_size (# of expected features in the input x)
//    H_cell = hidden_size (# of features in the hidden state h)
//    H_out = hidden_size
//
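// For example (assuming the Vulkan constraints checked below: D = 1,
// batch_first = True, N = 1, L = 1), with num_layers = 2, H_in = 16 and
// H_out = H_cell = 8: input_vk is (1, 1, 16), hx_vk and cx_vk are (2, 1, 8),
// the output is (1, 1, 8), and h_n / c_n are (2, 1, 8).
//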
std::tuple<Tensor, Tensor, Tensor> lstm_input(
    const Tensor& input_vk, // input sequence (vulkan)
    TensorList
        hx, // initial hidden state (vulkan) & initial cell state (vulkan)
    TensorList params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_CHECK(
      hx[0].size(2) == hx[1].size(2),
      "Vulkan LSTM with projections is not supported");
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan LSTM expects 'params_cpu' size to be 4 * 'num_layers'.");
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3, "Vulkan LSTM expects input dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx[0].sizes().size() == 3,
      "Vulkan LSTM expects hidden state dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx[1].sizes().size() == 3,
      "Vulkan LSTM expects cell state dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan LSTM expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan LSTM expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan LSTM expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan LSTM expects 'dropout' to be 0.0.");

  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan LSTM expects batch-first input");

  const Tensor& hx_vk = hx[0];
  const Tensor& cx_vk = hx[1];

  const auto hidden_size = hx_vk.size(2);
  std::vector<at::Tensor> h_n_list; // hidden state output
  std::vector<at::Tensor> c_n_list; // cell state output

  // reshape to 2D because the Vulkan at::mm op only accepts 2D tensors
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  h_n_list.reserve(num_layers);
  c_n_list.reserve(num_layers);

  for (int64_t l = 0; l < num_layers; ++l) {
    // extract this layer's hidden and cell state and squeeze them into 2D
    auto h = at::slice(hx_vk, 0, l, l + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    auto c = at::slice(cx_vk, 0, l, l + 1, 1);
    c = c.reshape({c.size(0) * c.size(1), c.size(2)});

    const auto& w_ih = params_cpu[l * 4];
    const auto& w_hh = params_cpu[l * 4 + 1];
    const auto& b_ih = params_cpu[l * 4 + 2];
    const auto& b_hh = params_cpu[l * 4 + 3];

    const auto& w_i_ifgo = w_ih.split(hidden_size);
    const auto& w_h_ifgo = w_hh.split(hidden_size);
    const auto& b_i_ifgo = b_ih.split(hidden_size);
    const auto& b_h_ifgo = b_hh.split(hidden_size);

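    // each weight/bias stacks the four gates along dim 0 in PyTorch's
    // (input, forget, cell, output) order, so each split yields i, f, g, o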
    const auto& w_ii = w_i_ifgo[0];
    const auto& w_if = w_i_ifgo[1];
    const auto& w_ig = w_i_ifgo[2];
    const auto& w_io = w_i_ifgo[3];
    const auto& w_hi = w_h_ifgo[0];
    const auto& w_hf = w_h_ifgo[1];
    const auto& w_hg = w_h_ifgo[2];
    const auto& w_ho = w_h_ifgo[3];
    const auto& b_ii = b_i_ifgo[0];
    const auto& b_if = b_i_ifgo[1];
    const auto& b_ig = b_i_ifgo[2];
    const auto& b_io = b_i_ifgo[3];
    const auto& b_hi = b_h_ifgo[0];
    const auto& b_hf = b_h_ifgo[1];
    const auto& b_hg = b_h_ifgo[2];
    const auto& b_ho = b_h_ifgo[3];

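    // standard LSTM cell update for this layer (single time step):
    //   i = sigmoid(x W_ii^T + b_ii + h W_hi^T + b_hi)
    //   f = sigmoid(x W_if^T + b_if + h W_hf^T + b_hf)
    //   g = tanh(x W_ig^T + b_ig + h W_hg^T + b_hg)
    //   o = sigmoid(x W_io^T + b_io + h W_ho^T + b_ho)
    //   c' = f * c + i * g
    //   h' = o * tanh(c')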
    const auto& i = at::sigmoid(
        at::addmm(b_ii, x, w_ii.t()) + at::addmm(b_hi, h, w_hi.t()));
    const auto& f = at::sigmoid(
        at::addmm(b_if, x, w_if.t()) + at::addmm(b_hf, h, w_hf.t()));
    const auto& g =
        at::tanh(at::addmm(b_ig, x, w_ig.t()) + at::addmm(b_hg, h, w_hg.t()));
    const auto& o = at::sigmoid(
        at::addmm(b_io, x, w_io.t()) + at::addmm(b_ho, h, w_ho.t()));
    c = f * c + i * g;
    h = o * at::tanh(c);
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
    c_n_list.emplace_back(
        c.reshape({1, 1, c.size(0), c.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  auto c_n = at::cat(c_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  c_n = c_n.reshape({c_n.size(0) * c_n.size(1), c_n.size(2), c_n.size(3)});
  return std::tuple<Tensor, Tensor, Tensor>(x, h_n, c_n);
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::lstm.input"), TORCH_FN(lstm_input));
}

#endif /* USE_VULKAN_API */

} // namespace

static std::vector<c10::intrusive_ptr<LinearPackedContext>>
pack_lstm_linear_op_contexts(
    const std::vector<Tensor>& params_cpu,
    int64_t num_layers) {
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan LSTM expects 'params_cpu' size to be 4 * 'num_layers'."
      " But 'params_cpu' has size: ",
      params_cpu.size(),
      " and 'num_layers' is: ",
      num_layers);
  std::vector<c10::intrusive_ptr<LinearPackedContext>> linear_op_contexts;
  linear_op_contexts.reserve(num_layers * 8);

  for (int64_t l = 0; l < num_layers; ++l) {
    const auto& w_ih = params_cpu[l * 4];
    const auto& w_hh = params_cpu[l * 4 + 1];
    const auto& b_ih = params_cpu[l * 4 + 2];
    const auto& b_hh = params_cpu[l * 4 + 3];
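    // w_ih stacks the four gate weight matrices along dim 0,
    // so its size(0) is 4 * hidden_size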
    const auto& hidden_size = w_ih.size(0) / 4;

    const auto& w_i_ifgo = w_ih.split(hidden_size);
    const auto& w_h_ifgo = w_hh.split(hidden_size);
    const auto& b_i_ifgo = b_ih.split(hidden_size);
    const auto& b_h_ifgo = b_hh.split(hidden_size);

    const auto& w_ii = w_i_ifgo[0];
    const auto& w_if = w_i_ifgo[1];
    const auto& w_ig = w_i_ifgo[2];
    const auto& w_io = w_i_ifgo[3];
    const auto& w_hi = w_h_ifgo[0];
    const auto& w_hf = w_h_ifgo[1];
    const auto& w_hg = w_h_ifgo[2];
    const auto& w_ho = w_h_ifgo[3];
    const auto& b_ii = b_i_ifgo[0];
    const auto& b_if = b_i_ifgo[1];
    const auto& b_ig = b_i_ifgo[2];
    const auto& b_io = b_i_ifgo[3];
    const auto& b_hi = b_h_ifgo[0];
    const auto& b_hf = b_h_ifgo[1];
    const auto& b_hg = b_h_ifgo[2];
    const auto& b_ho = b_h_ifgo[3];

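    // pack eight linear contexts per layer in the order consumed by
    // run_lstm_context: (ii, hi), (if, hf), (ig, hg), (io, ho); the weights
    // are transposed up front so each context applies directly to a 2D input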
    linear_op_contexts.emplace_back(create_linear_context(w_ii.t(), b_ii));
    linear_op_contexts.emplace_back(create_linear_context(w_hi.t(), b_hi));
    linear_op_contexts.emplace_back(create_linear_context(w_if.t(), b_if));
    linear_op_contexts.emplace_back(create_linear_context(w_hf.t(), b_hf));
    linear_op_contexts.emplace_back(create_linear_context(w_ig.t(), b_ig));
    linear_op_contexts.emplace_back(create_linear_context(w_hg.t(), b_hg));
    linear_op_contexts.emplace_back(create_linear_context(w_io.t(), b_io));
    linear_op_contexts.emplace_back(create_linear_context(w_ho.t(), b_ho));
  }
  return linear_op_contexts;
}

LstmPackedContext::LstmPackedContext(
    const std::vector<Tensor>& params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan LSTM expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan LSTM expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan LSTM expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan LSTM expects 'dropout' to be 0.0.");

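  // packed_ mirrors the Packed enum layout: the prepacked linear op contexts
  // first, followed by the scalar LSTM arguments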
  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(pack_lstm_linear_op_contexts(params_cpu, num_layers));
  packed_.emplace_back(has_biases);
  packed_.emplace_back(num_layers);
  packed_.emplace_back(dropout);
  packed_.emplace_back(train);
  packed_.emplace_back(bidirectional);
  packed_.emplace_back(batch_first);
}

LstmPackedContext LstmPackedContext::pack(c10::impl::GenericList unpacked) {
  return LstmPackedContext(
      unpacked.get(Unpacked::Params).toTensorVector(),
      unpacked.get(Unpacked::hasBiases).toBool(),
      unpacked.get(Unpacked::NumLayers).toInt(),
      unpacked.get(Unpacked::Dropout).toDouble(),
      unpacked.get(Unpacked::Train).toBool(),
      unpacked.get(Unpacked::Bidirectional).toBool(),
      unpacked.get(Unpacked::BatchFirst).toBool());
}

const c10::impl::GenericList LstmPackedContext::unpack() const {
  c10::impl::GenericList unpacked_lstm_context{c10::AnyType::get()};
  unpacked_lstm_context.reserve(Unpacked::NumArgs);

  const c10::List<c10::IValue> packed_linear_contexts =
      get_val(Packed::LinearContexts).toList();

  const int64_t num_layers = get_val(Packed::NumLayers).toInt();
  const int64_t linear_contexts_per_layer = 8;

  std::vector<Tensor> params_cpu;
  params_cpu.reserve(num_layers * linear_contexts_per_layer);

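  // each packed linear context holds a transposed weight plus its bias;
  // transpose the weight back when rebuilding the CPU parameter list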
  for (c10::IValue packed_linear_context : packed_linear_contexts) {
    const c10::impl::GenericList unpacked_linear_context =
        packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();

    TORCH_CHECK(
        unpacked_linear_context.size() > 0u,
        "unpacked_linear_context does not have any elements!");

    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Weight)
            .toTensor()
            .t());
    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Bias)
            .toTensor());
  }
  unpacked_lstm_context.emplace_back(params_cpu);
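  // the remaining packed values (has_biases through batch_first) are scalars
  // and can be copied over verbatim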
  for (int64_t i = 1; i < 7; ++i) {
    unpacked_lstm_context.emplace_back(get_val(i));
  }

  return unpacked_lstm_context;
}

c10::intrusive_ptr<LstmPackedContext> create_lstm_context(
    std::vector<Tensor>&& params_cpu,
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  return c10::make_intrusive<LstmPackedContext>(LstmPackedContext(
      params_cpu,
      has_biases,
      num_layers,
      dropout,
      train,
      bidirectional,
      batch_first));
}

std::tuple<Tensor, Tensor, Tensor> run_lstm_context(
    const Tensor& input_vk, // input sequence (vulkan)
    const Tensor& hx_vk, // initial hidden state (vulkan)
    const Tensor& cx_vk, // initial cell state (vulkan)
    const c10::intrusive_ptr<LstmPackedContext>& lstm_context) {
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3, "Vulkan LSTM expects input dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx_vk.sizes().size() == 3,
      "Vulkan LSTM expects hidden state dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      cx_vk.sizes().size() == 3,
      "Vulkan LSTM expects cell state dims to be 3.");

  const int64_t num_layers =
      lstm_context->get_val(LstmPackedContext::Packed::NumLayers).toInt();
  const bool batch_first =
      lstm_context->get_val(LstmPackedContext::Packed::BatchFirst).toBool();
  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan LSTM expects batch-first input");

  const c10::List<c10::IValue> packed_linear_op_contexts =
      lstm_context->get_val(LstmPackedContext::Packed::LinearContexts).toList();

  const int64_t linear_op_contexts_per_layer = 8;
  // (b_ii, w_ii), (b_hi, w_hi), (b_if, w_if), (b_hf, w_hf),
  // (b_ig, w_ig), (b_hg, w_hg), (b_io, w_io), (b_ho, w_ho)

  std::vector<at::Tensor> h_n_list; // hidden state output
  std::vector<at::Tensor> c_n_list; // cell state output

  // reshape to 2D because the Vulkan at::mm op only accepts 2D tensors
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  h_n_list.reserve(num_layers);
  c_n_list.reserve(num_layers);

  for (int64_t l = 0; l < num_layers; ++l) {
    // extract this layer's hidden and cell state and squeeze them into 2D
    auto h = at::slice(hx_vk, 0, l, l + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    auto c = at::slice(cx_vk, 0, l, l + 1, 1);
    c = c.reshape({c.size(0) * c.size(1), c.size(2)});

    const auto& cxt_ii =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 0]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hi =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 1]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_if =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 2]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hf =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 3]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_ig =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 4]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hg =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 5]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_io =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 6]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_ho =
        packed_linear_op_contexts[l * linear_op_contexts_per_layer + 7]
            .toCustomClass<LinearPackedContext>();

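    // same LSTM cell update as in lstm_input above, with each at::addmm
    // replaced by its prepacked Vulkan linear context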
    const auto& i = at::sigmoid(
        run_linear_context(x, cxt_ii) + run_linear_context(h, cxt_hi));
    // cxt_ii->run(x, 1.0f, 1.0f) + cxt_hi->run(h, 1.0f, 1.0f));
    const auto& f = at::sigmoid(
        run_linear_context(x, cxt_if) + run_linear_context(h, cxt_hf));
    // cxt_if->run(x, 1.0f, 1.0f) + cxt_hf->run(h, 1.0f, 1.0f));
    const auto& g =
        at::tanh(run_linear_context(x, cxt_ig) + run_linear_context(h, cxt_hg));
    // cxt_ig->run(x, 1.0f, 1.0f) + cxt_hg->run(h, 1.0f, 1.0f));
    const auto& o = at::sigmoid(
        run_linear_context(x, cxt_io) + run_linear_context(h, cxt_ho));
    // cxt_io->run(x, 1.0f, 1.0f) + cxt_ho->run(h, 1.0f, 1.0f));
    c = f * c + i * g;
    h = o * at::tanh(c);
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
    c_n_list.emplace_back(
        c.reshape({1, 1, c.size(0), c.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  auto c_n = at::cat(c_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  c_n = c_n.reshape({c_n.size(0) * c_n.size(1), c_n.size(2), c_n.size(3)});
  return std::tuple<Tensor, Tensor, Tensor>(x, h_n, c_n);
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at