1 #include <ATen/native/vulkan/ops/Layernorm.h>
2 #include <ATen/native/vulkan/ops/Utils.h>
3
4 #include <ATen/Context.h>
5 #include <c10/util/irange.h>
6
7 #include <ATen/native/vulkan/ops/Common.h>
8 #include <torch/library.h>
9
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #else
13 #include <ATen/ops/native_layer_norm.h>
14 #endif
15
16 namespace at {
17 namespace native {
18 namespace vulkan {
19 namespace ops {
20
LayernormPackedContext(const std::optional<Tensor> & weight,const std::optional<Tensor> & bias,double eps)21 LayernormPackedContext::LayernormPackedContext(
22 const std::optional<Tensor>& weight,
23 const std::optional<Tensor>& bias,
24 double eps)
25 : unpacked_{c10::AnyType::get()} {
26 packed_.reserve(ListArgs::kNumArgs);
27
28 TORCH_CHECK(weight, "Weight must be provided!");
29 packed_.emplace_back(weight->vulkan());
30 TORCH_CHECK(bias, "Bias must be provided!");
31 packed_.emplace_back(bias->vulkan());
32 packed_.emplace_back(eps);
33
34 if (!at::globalContext().releaseWeightsWhenPrepacking()) {
35 unpacked_.reserve(ListArgs::kNumArgs);
36 unpacked_.emplace_back(weight);
37 unpacked_.emplace_back(bias);
38 unpacked_.emplace_back(eps);
39 }
40 }
41
pack(c10::impl::GenericList unpacked)42 LayernormPackedContext LayernormPackedContext::pack(
43 c10::impl::GenericList unpacked) {
44 return LayernormPackedContext(
45 get_optional_tensor(unpacked, ListArgs::kWeight),
46 get_optional_tensor(unpacked, ListArgs::kBias),
47 unpacked.get(ListArgs::kEps).toDouble());
48 }
49
create_layernorm_context(std::optional<Tensor> && weight,std::optional<Tensor> && bias,double eps)50 c10::intrusive_ptr<LayernormPackedContext> create_layernorm_context(
51 std::optional<Tensor>&& weight,
52 std::optional<Tensor>&& bias,
53 double eps) {
54 return c10::make_intrusive<LayernormPackedContext>(
55 LayernormPackedContext(weight, bias, eps));
56 }
57
/**
 * Runs layer normalization with the parameters held by a packed context.
 *
 * @param input_arg         Input tensor; copied to Vulkan if it is not already
 *                          a Vulkan tensor.
 * @param normalized_shape  Forwarded unchanged to at::native_layer_norm.
 * @param layernorm_context Packed weight, bias and eps.
 * @return The normalized tensor, i.e. the first element of the tuple
 *         produced by at::native_layer_norm.
 */
Tensor run_layernorm_context(
    const Tensor& input_arg,
    IntArrayRef normalized_shape,
    const c10::intrusive_ptr<LayernormPackedContext>& layernorm_context) {
  using Args = LayernormPackedContext::ListArgs;

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();

  const std::optional<Tensor> weight_opt =
      layernorm_context->get_val(Args::kWeight).toTensor();
  const std::optional<Tensor> bias_opt =
      layernorm_context->get_val(Args::kBias).toTensor();

  // eps is stored as a double; narrow it (range-checked) to float before use.
  const double packed_eps = layernorm_context->get_val(Args::kEps).toDouble();
  const float eps = api::utils::safe_downcast<float>(packed_eps);

  // native_layer_norm returns <layer_norm, mean, 1/sqrt(var+eps)>; only the
  // normalized output is needed, so take it straight out of the temporary.
  return std::get<0>(at::native_layer_norm(
      input, normalized_shape, weight_opt, bias_opt, eps));
}
80
layer_norm(const at::Tensor & input_arg,IntArrayRef normalized_shape,const std::optional<Tensor> & weight_opt,const std::optional<Tensor> & bias_opt,double eps,bool)81 static Tensor layer_norm(
82 const at::Tensor& input_arg,
83 IntArrayRef normalized_shape,
84 const std::optional<Tensor>& weight_opt /* optional */,
85 const std::optional<Tensor>& bias_opt /* optional */,
86 double eps,
87 bool /* cudnn_enable, deprecated */) {
88 return run_layernorm_context(
89 input_arg,
90 normalized_shape,
91 c10::make_intrusive<LayernormPackedContext>(
92 LayernormPackedContext(weight_opt, bias_opt, eps)));
93 }
94
95 #ifdef USE_VULKAN_API
96
// Registers the static layer_norm function in this file as the Vulkan-backend
// implementation of aten::layer_norm. TORCH_SELECTIVE_NAME allows the
// registration to be stripped from selective (size-reduced) builds.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::layer_norm"), TORCH_FN(layer_norm));
}
100
101 #endif /* USE_VULKAN_API */
102
103 } // namespace ops
104 } // namespace vulkan
105 } // namespace native
106 } // namespace at
107