• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/native/vulkan/ops/Layernorm.h>
2 #include <ATen/native/vulkan/ops/Utils.h>
3 
4 #include <ATen/Context.h>
5 #include <c10/util/irange.h>
6 
7 #include <ATen/native/vulkan/ops/Common.h>
8 #include <torch/library.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #else
13 #include <ATen/ops/native_layer_norm.h>
14 #endif
15 
16 namespace at {
17 namespace native {
18 namespace vulkan {
19 namespace ops {
20 
LayernormPackedContext(const std::optional<Tensor> & weight,const std::optional<Tensor> & bias,double eps)21 LayernormPackedContext::LayernormPackedContext(
22     const std::optional<Tensor>& weight,
23     const std::optional<Tensor>& bias,
24     double eps)
25     : unpacked_{c10::AnyType::get()} {
26   packed_.reserve(ListArgs::kNumArgs);
27 
28   TORCH_CHECK(weight, "Weight must be provided!");
29   packed_.emplace_back(weight->vulkan());
30   TORCH_CHECK(bias, "Bias must be provided!");
31   packed_.emplace_back(bias->vulkan());
32   packed_.emplace_back(eps);
33 
34   if (!at::globalContext().releaseWeightsWhenPrepacking()) {
35     unpacked_.reserve(ListArgs::kNumArgs);
36     unpacked_.emplace_back(weight);
37     unpacked_.emplace_back(bias);
38     unpacked_.emplace_back(eps);
39   }
40 }
41 
pack(c10::impl::GenericList unpacked)42 LayernormPackedContext LayernormPackedContext::pack(
43     c10::impl::GenericList unpacked) {
44   return LayernormPackedContext(
45       get_optional_tensor(unpacked, ListArgs::kWeight),
46       get_optional_tensor(unpacked, ListArgs::kBias),
47       unpacked.get(ListArgs::kEps).toDouble());
48 }
49 
create_layernorm_context(std::optional<Tensor> && weight,std::optional<Tensor> && bias,double eps)50 c10::intrusive_ptr<LayernormPackedContext> create_layernorm_context(
51     std::optional<Tensor>&& weight,
52     std::optional<Tensor>&& bias,
53     double eps) {
54   return c10::make_intrusive<LayernormPackedContext>(
55       LayernormPackedContext(weight, bias, eps));
56 }
57 
/// Runs layer norm on `input_arg` using the parameters stored in
/// `layernorm_context`.
///
/// @param input_arg         Input tensor. It is converted with `.vulkan()`
///                          if it is not already a Vulkan tensor.
/// @param normalized_shape  Forwarded unchanged to at::native_layer_norm.
/// @param layernorm_context Packed context holding weight, bias and eps.
/// @return The normalized tensor, i.e. the first element of
///         at::native_layer_norm's result tuple.
Tensor run_layernorm_context(
    const Tensor& input_arg,
    IntArrayRef normalized_shape,
    const c10::intrusive_ptr<LayernormPackedContext>& layernorm_context) {
  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();

  // Hold the parameters by value. `get_val(...).toTensor()` yields a temporary
  // Tensor. Binding it to `const std::optional<Tensor>&` only worked through
  // lifetime extension of an implicitly converted temporary.
  const std::optional<Tensor> weight_opt =
      layernorm_context->get_val(LayernormPackedContext::ListArgs::kWeight)
          .toTensor();
  const std::optional<Tensor> bias_opt =
      layernorm_context->get_val(LayernormPackedContext::ListArgs::kBias)
          .toTensor();

  // eps is range-checked and narrowed to float here, as before.
  // NOTE(review): it is presumably widened back to double by
  // native_layer_norm; confirm the float narrowing is intended.
  const float eps = api::utils::safe_downcast<float>(
      layernorm_context->get_val(LayernormPackedContext::ListArgs::kEps)
          .toDouble());

  // native_layer_norm returns <layer_norm, mean, 1/sqrt(var+eps)>, and only
  // the first tensor is needed. Applying std::get to the rvalue tuple moves
  // the tensor out instead of copying it.
  return std::get<0>(at::native_layer_norm(
      input, normalized_shape, weight_opt, bias_opt, eps));
}
80 
layer_norm(const at::Tensor & input_arg,IntArrayRef normalized_shape,const std::optional<Tensor> & weight_opt,const std::optional<Tensor> & bias_opt,double eps,bool)81 static Tensor layer_norm(
82     const at::Tensor& input_arg,
83     IntArrayRef normalized_shape,
84     const std::optional<Tensor>& weight_opt /* optional */,
85     const std::optional<Tensor>& bias_opt /* optional */,
86     double eps,
87     bool /* cudnn_enable, deprecated */) {
88   return run_layernorm_context(
89       input_arg,
90       normalized_shape,
91       c10::make_intrusive<LayernormPackedContext>(
92           LayernormPackedContext(weight_opt, bias_opt, eps)));
93 }
94 
#ifdef USE_VULKAN_API

// Dispatch aten::layer_norm on the Vulkan backend to the layer_norm function
// defined in this file. The registration is only compiled when the Vulkan API
// is enabled.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::layer_norm"),
      TORCH_FN(layer_norm));
}

#endif /* USE_VULKAN_API */
102 
103 } // namespace ops
104 } // namespace vulkan
105 } // namespace native
106 } // namespace at
107