• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/ops/Common.h>
6 #include <ATen/native/vulkan/ops/VulkanPackedContext.h>
7 #include <torch/library.h>
8 
9 namespace at {
10 namespace native {
11 namespace vulkan {
12 namespace ops {
13 
14 class GruPackedContext final : virtual public VulkanPackedContext,
15                                public torch::jit::CustomClassHolder {
16  public:
17   GruPackedContext(
18       const std::vector<Tensor>& params_cpu, // weights/biases (cpu)
19       bool has_biases,
20       int64_t num_layers,
21       double dropout,
22       bool train,
23       bool bidirectional,
24       bool batch_first);
25 
26   /*
27    * Assigns a name to each index in the unpacked list.
28    */
29   struct Unpacked final {
30     static constexpr uint32_t Params = 0u;
31     static constexpr uint32_t hasBiases = 1u;
32     static constexpr uint32_t NumLayers = 2u;
33     static constexpr uint32_t Dropout = 3u;
34     static constexpr uint32_t Train = 4u;
35     static constexpr uint32_t Bidirectional = 5u;
36     static constexpr uint32_t BatchFirst = 6u;
37 
38     static constexpr uint32_t NumArgs = 7u;
39   };
40 
41   /*
42    * Assigns a name to each index in the packed list.
43    */
44   struct Packed final {
45     static constexpr uint32_t LinearContexts = 0u;
46     static constexpr uint32_t hasBiases = 1u;
47     static constexpr uint32_t NumLayers = 2u;
48     static constexpr uint32_t Dropout = 3u;
49     static constexpr uint32_t Train = 4u;
50     static constexpr uint32_t Bidirectional = 5u;
51     static constexpr uint32_t BatchFirst = 6u;
52 
53     static constexpr uint32_t NumArgs = 7u;
54   };
55 
56   static GruPackedContext pack(c10::impl::GenericList);
57 
58   const c10::impl::GenericList unpack() const override;
59 };
60 
61 c10::intrusive_ptr<GruPackedContext> create_gru_context(
62     std::vector<Tensor>&& params_cpu, // weights/biases (cpu)
63     bool has_biases,
64     int64_t num_layers,
65     double dropout,
66     bool train,
67     bool bidirectional,
68     bool batch_first);
69 
70 std::tuple<Tensor, Tensor> run_gru_context(
71     const Tensor& input_vk,
72     const Tensor& hx_vk,
73     const c10::intrusive_ptr<GruPackedContext>& vulkan_context);
74 
75 } // namespace ops
76 } // namespace vulkan
77 } // namespace native
78 } // namespace at
79 
80 #endif /* USE_VULKAN_API */
81