• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#pragma once

#ifdef USE_VULKAN_API

#include <cstdint>
#include <optional>

#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

/*
 * Packed context for the Vulkan layer-norm op.
 *
 * Stores the op's non-input arguments (weight, bias, eps) in `unpacked_`, a
 * generic list whose slots are named by `ListArgs`. Implements the
 * `VulkanPackedContext` interface and derives from
 * `torch::jit::CustomClassHolder`; instances are handed out as
 * `c10::intrusive_ptr<LayernormPackedContext>` by `create_layernorm_context`
 * below.
 *
 * NOTE(review): whether anything beyond the unpacked argument list is packed
 * (e.g. device-side copies of weight/bias) is decided in the .cpp, not here.
 */
class LayernormPackedContext final : virtual public VulkanPackedContext,
                                     public torch::jit::CustomClassHolder {
 private:
  // Original (unpacked) arguments, one entry per ListArgs index.
  // unpack() refuses to return it while it is empty.
  c10::impl::GenericList unpacked_;

 public:
  /*
   * Builds a context from the layer-norm parameters.
   *
   * @param weight optional weight tensor (presumably the elementwise affine
   *               scale — confirm in the .cpp)
   * @param bias   optional bias tensor (presumably the elementwise affine
   *               shift — confirm in the .cpp)
   * @param eps    epsilon value for the normalization
   */
  LayernormPackedContext(
      const std::optional<Tensor>& weight,
      const std::optional<Tensor>& bias,
      double eps);

  /*
   * Assigns a name to each index in the unpacked list.
   */
  struct ListArgs final {
    static constexpr uint32_t kWeight = 0u;
    static constexpr uint32_t kBias = 1u;
    static constexpr uint32_t kEps = 2u;

    // Total number of entries in the unpacked list.
    static constexpr uint32_t kNumArgs = 3u;
  };

  /*
   * Rebuilds a context from a list of unpacked arguments.
   *
   * NOTE(review): the list is presumably laid out per ListArgs (the same
   * layout unpack() returns); the definition lives in the .cpp — confirm.
   */
  static LayernormPackedContext pack(const c10::impl::GenericList);

  /*
   * Returns the stored argument list, indexed by ListArgs.
   *
   * Fails a TORCH_CHECK if the list has no elements.
   */
  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

    return unpacked_;
  }
};

/*
 * Creates a LayernormPackedContext from the given parameters and returns it
 * behind a reference-counted c10::intrusive_ptr.
 *
 * NOTE(review): weight and bias are taken by rvalue reference; whether they
 * are actually moved from is decided in the .cpp — confirm before reusing
 * the caller's objects.
 */
c10::intrusive_ptr<LayernormPackedContext> create_layernorm_context(
    std::optional<Tensor>&& weight,
    std::optional<Tensor>&& bias,
    double eps);

/*
 * Runs layer norm on `input` using the parameters held by `context`.
 *
 * NOTE(review): `normalized_shape` presumably follows aten::layer_norm
 * semantics (the trailing dimensions to normalize over) — confirm in the .cpp.
 */
Tensor run_layernorm_context(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const c10::intrusive_ptr<LayernormPackedContext>& context);

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */