#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

/*
 * Packed context for the Vulkan layernorm op.
 *
 * Derives (virtually) from VulkanPackedContext and from
 * torch::jit::CustomClassHolder, and keeps a generic list of the unpacked
 * arguments that unpack() hands back. The slots of that list are named by
 * ListArgs below.
 */
class LayernormPackedContext final : virtual public VulkanPackedContext,
                                     public torch::jit::CustomClassHolder {
 private:
  // Arguments in unpacked form; returned by unpack().
  c10::impl::GenericList unpacked_;

 public:
  /*
   * Builds a context from the layernorm parameters.
   *
   * weight and bias are optional at the type level; eps is passed as a
   * double. The definition is not in this header.
   * NOTE(review): whether an absent weight/bias is accepted or rejected is
   * decided in the implementation file - confirm there.
   */
  LayernormPackedContext(
      const std::optional<Tensor>& weight,
      const std::optional<Tensor>& bias,
      double eps);

  /*
   * Assigns a name to each index in the unpacked list.
   */
  struct ListArgs final {
    static constexpr uint32_t kWeight = 0u;
    static constexpr uint32_t kBias = 1u;
    static constexpr uint32_t kEps = 2u;

    // Total number of named slots above.
    static constexpr uint32_t kNumArgs = 3u;
  };

  /*
   * Creates a context from a generic list of arguments.
   *
   * Presumably the list is laid out according to ListArgs; the definition is
   * not in this header, so verify against the implementation.
   */
  static LayernormPackedContext pack(const c10::impl::GenericList);

  /*
   * Returns the stored unpacked argument list.
   *
   * Fails via TORCH_CHECK if the stored list is empty.
   */
  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

    return unpacked_;
  }
};

/*
 * Creates a LayernormPackedContext and returns it behind a
 * c10::intrusive_ptr. weight and bias are taken by rvalue reference.
 * Declared here; defined elsewhere.
 */
c10::intrusive_ptr<LayernormPackedContext> create_layernorm_context(
    std::optional<Tensor>&& weight,
    std::optional<Tensor>&& bias,
    double eps);

/*
 * Runs layernorm on input using a previously created context and returns the
 * resulting Tensor. Declared here; defined elsewhere.
 * NOTE(review): constraints on input / normalized_shape are enforced (if at
 * all) in the implementation file - confirm there.
 */
Tensor run_layernorm_context(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const c10::intrusive_ptr<LayernormPackedContext>& context);

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */