#include <ATen/TensorOperators.h>
#include <ATen/native/vulkan/ops/Gru.h>
#include <ATen/native/vulkan/ops/Mm.h>
#include <limits>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/addmm.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/gru.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/tanh.h>
#endif

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {
//
// input_vk: input tensor containing the features of the input sequence
//           tensor of shape (N, L, H_in) when batch_first=True
//                           (L, N, H_in) when batch_first=False
//
// hx_vk: initial hidden state for each element in the batch.
//        tensor of shape (D * num_layers, N, H_out)
//
// output: tensor of shape (N, L, D * H_out) when batch_first=True
//                         (L, N, D * H_out) when batch_first=False
//
// h_n: tensor of shape (D * num_layers, N, H_out)
//
// where
//    L = sequence length
//    N = batch size
//    D = 2 if bidirectional=True otherwise 1
//    H_in = input_size (# of expected features in the input x)
//    H_out = hidden_size (# of features in the hidden state h)
//
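// Each layer applies the standard GRU cell update (the same formulation
// documented for at::gru / torch.nn.GRU):
//
//    r  = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
//    z  = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
//    n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))
//    h' = (1 - z) * n + z * h
//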
std::tuple<Tensor, Tensor> gru_input(
    const Tensor& input_vk, // input sequence (vulkan)
    const Tensor& hx_vk, // initial hidden state (vulkan)
    TensorList params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'.");
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3,
      "Vulkan gru expects 'input_vk' dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan gru expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan gru expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan gru expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan gru expects 'dropout' to be 0.0.");

  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan gru expects batch-first input");

  const auto hidden_size = hx_vk.size(2);
  std::vector<at::Tensor> h_n_list; // hidden output

  // reshape to 2D since the Vulkan at::mm op only accepts 2D tensors
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  for (int64_t i = 0; i < num_layers; ++i) {
    // extract the hidden state for this layer and squeeze it to 2D
    auto h = at::slice(hx_vk, 0, i, i + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    const auto& w_ih = params_cpu[i * 4];
    const auto& w_hh = params_cpu[i * 4 + 1];
    const auto& b_ih = params_cpu[i * 4 + 2];
    const auto& b_hh = params_cpu[i * 4 + 3];

    // split the stacked weights/biases into the r, z, n gate chunks
    const auto& w_i_rzn = w_ih.split(hidden_size);
    const auto& w_h_rzn = w_hh.split(hidden_size);
    const auto& b_i_rzn = b_ih.split(hidden_size);
    const auto& b_h_rzn = b_hh.split(hidden_size);

    const auto& w_ir = w_i_rzn[0];
    const auto& w_iz = w_i_rzn[1];
    const auto& w_in = w_i_rzn[2];
    const auto& w_hr = w_h_rzn[0];
    const auto& w_hz = w_h_rzn[1];
    const auto& w_hn = w_h_rzn[2];
    const auto& b_ir = b_i_rzn[0];
    const auto& b_iz = b_i_rzn[1];
    const auto& b_in = b_i_rzn[2];
    const auto& b_hr = b_h_rzn[0];
    const auto& b_hz = b_h_rzn[1];
    const auto& b_hn = b_h_rzn[2];

    const auto& r = at::sigmoid(
        at::addmm(b_ir, x, w_ir.t()) + at::addmm(b_hr, h, w_hr.t()));
    const auto& z = at::sigmoid(
        at::addmm(b_iz, x, w_iz.t()) + at::addmm(b_hz, h, w_hz.t()));
    const auto& n = at::tanh(
        at::addmm(b_in, x, w_in.t()) + r * (at::addmm(b_hn, h, w_hn.t())));
    h = (z * (-1) + 1) * n + z * h; // i.e. (1 - z) * n + z * h
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  return std::tuple<Tensor, Tensor>(x, h_n);
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::gru.input"), TORCH_FN(gru_input));
}
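// With this registration in place, a regular at::gru call on Vulkan tensors
// dispatches to gru_input above. Illustrative sketch only (the tensor names
// are hypothetical):
//
//   auto out = at::gru(
//       input.vulkan(), hx.vulkan(), params_cpu,
//       /*has_biases=*/true, /*num_layers=*/1, /*dropout=*/0.0,
//       /*train=*/false, /*bidirectional=*/false, /*batch_first=*/true);
//   const Tensor& output = std::get<0>(out); // (N, L, H_out), on Vulkan
//   const Tensor& h_n = std::get<1>(out); // (num_layers, N, H_out)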

#endif /* USE_VULKAN_API */

} // namespace

static std::vector<c10::intrusive_ptr<LinearPackedContext>>
pack_linear_op_contexts(
    const std::vector<Tensor>& params_cpu,
    int64_t num_layers) {
  TORCH_CHECK(
      static_cast<int64_t>(params_cpu.size()) == 4 * num_layers,
      "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'."
      " But 'params_cpu' has size: ",
      params_cpu.size(),
      " and 'num_layers' is: ",
      num_layers);
  std::vector<c10::intrusive_ptr<LinearPackedContext>> linear_op_contexts;
  linear_op_contexts.reserve(num_layers * 6);

  for (int64_t i = 0; i < num_layers; ++i) {
    const auto& w_ih = params_cpu.at(i * 4);
    const auto& w_hh = params_cpu.at(i * 4 + 1);
    const auto& b_ih = params_cpu.at(i * 4 + 2);
    const auto& b_hh = params_cpu.at(i * 4 + 3);
    const auto& hidden_size = w_ih.size(0) / 3;

    const auto& w_i_rzn = w_ih.split(hidden_size);
    const auto& w_h_rzn = w_hh.split(hidden_size);
    const auto& b_i_rzn = b_ih.split(hidden_size);
    const auto& b_h_rzn = b_hh.split(hidden_size);

    const auto& w_ir = w_i_rzn[0];
    const auto& w_iz = w_i_rzn[1];
    const auto& w_in = w_i_rzn[2];
    const auto& w_hr = w_h_rzn[0];
    const auto& w_hz = w_h_rzn[1];
    const auto& w_hn = w_h_rzn[2];
    const auto& b_ir = b_i_rzn[0];
    const auto& b_iz = b_i_rzn[1];
    const auto& b_in = b_i_rzn[2];
    const auto& b_hr = b_h_rzn[0];
    const auto& b_hz = b_h_rzn[1];
    const auto& b_hn = b_h_rzn[2];

    // one linear context per gate, in the order consumed by run_gru_context
    linear_op_contexts.emplace_back(create_linear_context(w_ir.t(), b_ir));
    linear_op_contexts.emplace_back(create_linear_context(w_hr.t(), b_hr));
    linear_op_contexts.emplace_back(create_linear_context(w_iz.t(), b_iz));
    linear_op_contexts.emplace_back(create_linear_context(w_hz.t(), b_hz));
    linear_op_contexts.emplace_back(create_linear_context(w_in.t(), b_in));
    linear_op_contexts.emplace_back(create_linear_context(w_hn.t(), b_hn));
  }
  return linear_op_contexts;
}

GruPackedContext::GruPackedContext(
    const std::vector<Tensor>& params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_INTERNAL_ASSERT(
      has_biases, "Vulkan gru expects 'has_biases' to be true.");
  TORCH_INTERNAL_ASSERT(!train, "Vulkan gru expects 'train' to be false.");
  TORCH_INTERNAL_ASSERT(
      !bidirectional, "Vulkan gru expects 'bidirectional' to be false.");
  TORCH_INTERNAL_ASSERT(
      dropout < std::numeric_limits<double>::epsilon() * 1000,
      "Vulkan gru expects 'dropout' to be 0.0.");

  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(pack_linear_op_contexts(params_cpu, num_layers));
  packed_.emplace_back(has_biases);
  packed_.emplace_back(num_layers);
  packed_.emplace_back(dropout);
  packed_.emplace_back(train);
  packed_.emplace_back(bidirectional);
  packed_.emplace_back(batch_first);
}
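// Note: packed_ is populated as [linear contexts, has_biases, num_layers,
// dropout, train, bidirectional, batch_first]; the Packed::* indices used by
// get_val() below (LinearContexts, NumLayers, BatchFirst) are assumed to
// follow this same order.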

GruPackedContext GruPackedContext::pack(c10::impl::GenericList unpacked) {
  return GruPackedContext(
      unpacked.get(Unpacked::Params).toTensorVector(),
      unpacked.get(Unpacked::hasBiases).toBool(),
      unpacked.get(Unpacked::NumLayers).toInt(),
      unpacked.get(Unpacked::Dropout).toDouble(),
      unpacked.get(Unpacked::Train).toBool(),
      unpacked.get(Unpacked::Bidirectional).toBool(),
      unpacked.get(Unpacked::BatchFirst).toBool());
}

const c10::impl::GenericList GruPackedContext::unpack() const {
  c10::impl::GenericList unpacked_gru_context{c10::AnyType::get()};
  unpacked_gru_context.reserve(Unpacked::NumArgs);

  const c10::List<c10::IValue> packed_linear_contexts =
      get_val(Packed::LinearContexts).toList();

  const int64_t num_layers = get_val(Packed::NumLayers).toInt();
  const int64_t linear_contexts_per_layer = 6;

  std::vector<Tensor> params_cpu;
  params_cpu.reserve(num_layers * linear_contexts_per_layer);

  for (c10::IValue packed_linear_context : packed_linear_contexts) {
    const c10::impl::GenericList unpacked_linear_context =
        packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();

    TORCH_CHECK(
        unpacked_linear_context.size() > 0u,
        "unpacked_linear_context does not have any elements!");

    // transpose back to undo the .t() applied in pack_linear_op_contexts
    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Weight)
            .toTensor()
            .t());
    params_cpu.emplace_back(
        unpacked_linear_context.get(LinearPackedContext::Unpacked::Bias)
            .toTensor());
  }
  unpacked_gru_context.emplace_back(params_cpu);
  for (int64_t i = 1; i < Unpacked::NumArgs; ++i) {
    unpacked_gru_context.emplace_back(get_val(i));
  }

  return unpacked_gru_context;
}

c10::intrusive_ptr<GruPackedContext> create_gru_context(
    std::vector<Tensor>&& params_cpu,
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first) {
  return c10::make_intrusive<GruPackedContext>(GruPackedContext(
      params_cpu,
      has_biases,
      num_layers,
      dropout,
      train,
      bidirectional,
      batch_first));
}

std::tuple<Tensor, Tensor> run_gru_context(
    const Tensor& input_vk, // input sequence (vulkan)
    const Tensor& hx_vk, // initial hidden state (vulkan)
    const c10::intrusive_ptr<GruPackedContext>& gru_context) {
  TORCH_INTERNAL_ASSERT(
      input_vk.sizes().size() == 3,
      "Vulkan gru expects 'input_vk' dims to be 3.");
  TORCH_INTERNAL_ASSERT(
      hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3.");

  const int64_t num_layers =
      gru_context->get_val(GruPackedContext::Packed::NumLayers).toInt();
  const bool batch_first =
      gru_context->get_val(GruPackedContext::Packed::BatchFirst).toBool();
  const auto batch_size = input_vk.size(0);
  const auto seq_length = input_vk.size(1);

  TORCH_INTERNAL_ASSERT(
      (batch_size == 1 && seq_length == 1) || batch_first,
      "Vulkan gru expects batch-first input");

  const c10::List<c10::IValue> packed_linear_contexts =
      gru_context->get_val(GruPackedContext::Packed::LinearContexts).toList();

  const int64_t linear_contexts_per_layer = 6;
  // per layer: (b_ir, w_ir), (b_hr, w_hr), (b_iz, w_iz),
  //            (b_hz, w_hz), (b_in, w_in), (b_hn, w_hn)
  std::vector<at::Tensor> h_n_list; // hidden output

  // reshape to 2D since the Vulkan at::mm op only accepts 2D tensors
  auto x = input_vk.reshape({batch_size * seq_length, input_vk.size(2)});

  for (int64_t i = 0; i < num_layers; ++i) {
    // extract the hidden state for this layer and squeeze it to 2D
    auto h = at::slice(hx_vk, 0, i, i + 1, 1);
    h = h.reshape({h.size(0) * h.size(1), h.size(2)});

    const auto& cxt_ir =
        packed_linear_contexts[i * linear_contexts_per_layer + 0]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hr =
        packed_linear_contexts[i * linear_contexts_per_layer + 1]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_iz =
        packed_linear_contexts[i * linear_contexts_per_layer + 2]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hz =
        packed_linear_contexts[i * linear_contexts_per_layer + 3]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_in =
        packed_linear_contexts[i * linear_contexts_per_layer + 4]
            .toCustomClass<LinearPackedContext>();
    const auto& cxt_hn =
        packed_linear_contexts[i * linear_contexts_per_layer + 5]
            .toCustomClass<LinearPackedContext>();

    const auto& r = at::sigmoid(
        run_linear_context(x, cxt_ir) + run_linear_context(h, cxt_hr));
    const auto& z = at::sigmoid(
        run_linear_context(x, cxt_iz) + run_linear_context(h, cxt_hz));
    const auto& n = at::tanh(
        run_linear_context(x, cxt_in) + r * run_linear_context(h, cxt_hn));
    h = (z * (-1) + 1) * n + z * h; // i.e. (1 - z) * n + z * h
    x = h; // next input
    h_n_list.emplace_back(
        h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op
  }

  auto h_n = at::cat(h_n_list, 1);
  x = x.reshape({batch_size, seq_length, x.size(1)});
  h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)});
  return std::tuple<Tensor, Tensor>(x, h_n);
}
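// The prepacked path mirrors the direct op: pack the CPU weights/biases once
// with create_gru_context, then call run_gru_context per inference.
// Illustrative sketch only (the tensor names are hypothetical):
//
//   auto gru_ctx = create_gru_context(
//       std::move(params_cpu), /*has_biases=*/true, /*num_layers=*/1,
//       /*dropout=*/0.0, /*train=*/false, /*bidirectional=*/false,
//       /*batch_first=*/true);
//   auto out = run_gru_context(input.vulkan(), hx.vulkan(), gru_ctx);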

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at