• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/ops/Common.h>
6 #include <ATen/native/vulkan/ops/VulkanPackedContext.h>
7 #include <torch/library.h>
8 
9 namespace at {
10 namespace native {
11 namespace vulkan {
12 namespace ops {
13 
14 class LstmPackedContext final : virtual public VulkanPackedContext,
15                                 public torch::jit::CustomClassHolder {
16  public:
17   LstmPackedContext(
18       const std::vector<Tensor>& params_cpu, // weights/biases (cpu)
19       bool has_biases,
20       int64_t num_layers,
21       double dropout,
22       bool train,
23       bool bidirectional,
24       bool batch_first);
25 
26   /*
27    * Assigns a name to each index in the unpacked list.
28    */
29   struct Unpacked final {
30     static constexpr uint32_t Params = 0u;
31     static constexpr uint32_t hasBiases = 1u;
32     static constexpr uint32_t NumLayers = 2u;
33     static constexpr uint32_t Dropout = 3u;
34     static constexpr uint32_t Train = 4u;
35     static constexpr uint32_t Bidirectional = 5u;
36     static constexpr uint32_t BatchFirst = 6u;
37 
38     static constexpr uint32_t NumArgs = 7u;
39   };
40 
41   /*
42    * Assigns a name to each index in the packed list.
43    */
44   struct Packed final {
45     static constexpr uint32_t LinearContexts = 0u;
46     static constexpr uint32_t hasBiases = 1u;
47     static constexpr uint32_t NumLayers = 2u;
48     static constexpr uint32_t Dropout = 3u;
49     static constexpr uint32_t Train = 4u;
50     static constexpr uint32_t Bidirectional = 5u;
51     static constexpr uint32_t BatchFirst = 6u;
52 
53     static constexpr uint32_t NumArgs = 7u;
54   };
55 
56   static LstmPackedContext pack(c10::impl::GenericList);
57 
58   const c10::impl::GenericList unpack() const override;
59 };
60 
61 c10::intrusive_ptr<LstmPackedContext> create_lstm_context(
62     std::vector<Tensor>&& params_cpu, // weights/biases (cpu)
63     bool has_biases,
64     int64_t num_layers,
65     double dropout,
66     bool train,
67     bool bidirectional,
68     bool batch_first);
69 
70 std::tuple<Tensor, Tensor, Tensor> run_lstm_context(
71     const Tensor& input_vk, // input sequence (vulkan)
72     const Tensor& hx_vk, // initial hidden state (vulkan)
73     const Tensor& cx_vk, // initial cell state (vulkan)
74     const c10::intrusive_ptr<LstmPackedContext>& vulkan_context);
75 
76 } // namespace ops
77 } // namespace vulkan
78 } // namespace native
79 } // namespace at
80 
81 #endif /* USE_VULKAN_API */
82