#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
#include <torch/library.h>

// Standard headers for names this file uses directly (std::vector,
// std::tuple, int64_t / uint32_t) rather than relying on transitive includes.
#include <cstdint>
#include <tuple>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

/*
 * Packed context for the Vulkan LSTM op.
 *
 * Derives from VulkanPackedContext (virtually) and from
 * torch::jit::CustomClassHolder. The context is built from the LSTM's CPU
 * parameters plus its configuration flags, and can be converted to and from a
 * c10::impl::GenericList via pack() / unpack().
 */
class LstmPackedContext final : virtual public VulkanPackedContext,
                                public torch::jit::CustomClassHolder {
 public:
  /*
   * Builds a context from the LSTM arguments.
   *
   * params_cpu    weights/biases (cpu)
   * has_biases    NOTE(review): presumably whether params_cpu contains bias
   *               tensors - confirm against the implementation
   * num_layers, dropout, train, bidirectional, batch_first
   *               remaining LSTM configuration values; how each one is
   *               validated or used is defined in the implementation file,
   *               not in this header
   */
  LstmPackedContext(
      const std::vector<Tensor>& params_cpu, // weights/biases (cpu)
      bool has_biases,
      int64_t num_layers,
      double dropout,
      bool train,
      bool bidirectional,
      bool batch_first);

  /*
   * Assigns a name to each index in the unpacked list.
   *
   * The order of the indices mirrors the order of the constructor's
   * parameters. `hasBiases` is spelled in lowerCamelCase, unlike its
   * siblings; the spelling is part of the public interface and is kept as-is.
   */
  struct Unpacked final {
    static constexpr uint32_t Params = 0u;
    static constexpr uint32_t hasBiases = 1u;
    static constexpr uint32_t NumLayers = 2u;
    static constexpr uint32_t Dropout = 3u;
    static constexpr uint32_t Train = 4u;
    static constexpr uint32_t Bidirectional = 5u;
    static constexpr uint32_t BatchFirst = 6u;

    // Total number of entries in the unpacked list.
    static constexpr uint32_t NumArgs = 7u;
  };

  /*
   * Assigns a name to each index in the packed list.
   *
   * Same layout as Unpacked, except that index 0 is named LinearContexts
   * instead of Params.
   * NOTE(review): presumably index 0 holds linear-op contexts derived from
   * the parameters - confirm against the implementation.
   */
  struct Packed final {
    static constexpr uint32_t LinearContexts = 0u;
    static constexpr uint32_t hasBiases = 1u;
    static constexpr uint32_t NumLayers = 2u;
    static constexpr uint32_t Dropout = 3u;
    static constexpr uint32_t Train = 4u;
    static constexpr uint32_t Bidirectional = 5u;
    static constexpr uint32_t BatchFirst = 6u;

    // Total number of entries in the packed list.
    static constexpr uint32_t NumArgs = 7u;
  };

  /*
   * Creates a context from a generic list.
   * NOTE(review): the list is presumably laid out according to the Packed
   * indices - confirm against the implementation.
   */
  static LstmPackedContext pack(c10::impl::GenericList);

  /*
   * Returns this context as a generic list. Overrides a base-class virtual,
   * so the return type (including its top-level const) has to stay in sync
   * with the overridden declaration.
   */
  const c10::impl::GenericList unpack() const override;
};

/*
 * Creates a reference-counted LstmPackedContext. The parameters correspond
 * one-to-one to the LstmPackedContext constructor's, except that params_cpu
 * is taken by rvalue reference.
 */
c10::intrusive_ptr<LstmPackedContext> create_lstm_context(
    std::vector<Tensor>&& params_cpu, // weights/biases (cpu)
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    bool batch_first);

/*
 * Runs the LSTM described by vulkan_context on the given inputs.
 *
 * Returns a tuple of three tensors.
 * NOTE(review): presumably (output, final hidden state, final cell state),
 * matching the usual LSTM convention - confirm against the implementation.
 */
std::tuple<Tensor, Tensor, Tensor> run_lstm_context(
    const Tensor& input_vk, // input sequence (vulkan)
    const Tensor& hx_vk, // initial hidden state (vulkan)
    const Tensor& cx_vk, // initial cell state (vulkan)
    const c10::intrusive_ptr<LstmPackedContext>& vulkan_context);

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */