1 #pragma once
2
3 #include <torch/nn/functional/padding.h>
4 #include <torch/nn/functional/pooling.h>
5 #include <torch/nn/options/normalization.h>
6 #include <torch/types.h>
7
8 namespace torch {
9 namespace nn {
10 namespace functional {
11
12 #ifndef DOXYGEN_SHOULD_SKIP_THIS
13 namespace detail {
/// Divides `input` by its `p`-norm taken along `dim`.
///
/// The norm is clamped from below by `eps` before dividing, so (for a
/// positive `eps`) slices whose norm is zero do not produce a division by
/// zero.
///
/// \param input tensor to normalize.
/// \param p     order of the norm passed to `Tensor::norm`.
/// \param dim   dimension along which the norm is computed.
/// \param eps   lower bound applied to the norm before dividing.
/// \param out   optional destination tensor; when set, the quotient is written
///              into `*out` via `torch::div_out` and that tensor is returned.
/// \return the normalized tensor (or `*out` when `out` is provided).
inline Tensor normalize(
    const Tensor& input,
    double p,
    int64_t dim,
    double eps,
    std::optional<Tensor> out) {
  // The denominator is the same for both the allocating and the `out=` path,
  // so compute it once. keepdim=true keeps the reduced dimension as size 1,
  // which lets `expand_as` broadcast the norm back to `input`'s shape.
  auto denom =
      input.norm(p, dim, /*keepdim=*/true).clamp_min(eps).expand_as(input);
  if (!out.has_value()) {
    return input / denom;
  }
  return torch::div_out(*out, input, denom);
}
28 } // namespace detail
29 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
30
31 /// See
32 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.normalize
33 /// about the exact behavior of this functional.
34 ///
35 /// See the documentation for `torch::nn::functional::NormalizeFuncOptions`
36 /// class to learn what optional arguments are supported for this functional.
37 ///
38 /// Example:
39 /// ```
40 /// namespace F = torch::nn::functional;
41 /// F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
42 /// ```
43 inline Tensor normalize(
44 const Tensor& input,
45 NormalizeFuncOptions options = {}) {
46 return detail::normalize(
47 input, options.p(), options.dim(), options.eps(), options.out());
48 }
49
50 // ============================================================================
51
52 #ifndef DOXYGEN_SHOULD_SKIP_THIS
53 namespace detail {
inline Tensor layer_norm(
    const Tensor& input,
    const std::vector<int64_t>& normalized_shape,
    const Tensor& weight,
    const Tensor& bias,
    double eps) {
  // Pure forwarding shim onto the ATen `layer_norm` operator; the
  // `cudnn_enable` argument is left at the operator's default.
  return torch::layer_norm(
      /*input=*/input,
      /*normalized_shape=*/normalized_shape,
      /*weight=*/weight,
      /*bias=*/bias,
      /*eps=*/eps);
}
62 } // namespace detail
63 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
64
65 /// See
66 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.layer_norm
67 /// about the exact behavior of this functional.
68 ///
69 /// See the documentation for `torch::nn::functional::LayerNormFuncOptions`
70 /// class to learn what optional arguments are supported for this functional.
71 ///
72 /// Example:
73 /// ```
74 /// namespace F = torch::nn::functional;
75 /// F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5));
76 /// ```
inline Tensor layer_norm(
    const Tensor& input,
    const LayerNormFuncOptions& options) {
  // Bind the option fields by reference (no copies) and forward them.
  const auto& normalized_shape = options.normalized_shape();
  const auto& weight = options.weight();
  const auto& bias = options.bias();
  return detail::layer_norm(
      input, normalized_shape, weight, bias, options.eps());
}
87
88 // ============================================================================
89
90 #ifndef DOXYGEN_SHOULD_SKIP_THIS
91 namespace detail {
/// Local response normalization across the channel dimension (dim 1).
///
/// Each element `x` is divided by `(k + alpha * m)^beta`, where `m` is the
/// mean of squared values over a window of `size` neighbouring channels. The
/// channel axis is zero-padded with `size / 2` channels before and
/// `(size - 1) / 2` channels after, and the padded zeros count toward the
/// mean.
///
/// \param input tensor with at least 3 dimensions.
/// \param size  number of channels in the averaging window.
/// \param alpha multiplicative scale applied to the window mean.
/// \param beta  exponent applied to the denominator.
/// \param k     additive constant in the denominator.
/// \return the normalized tensor, or `input` itself when it has no elements.
inline Tensor local_response_norm(
    const Tensor& input,
    int64_t size,
    double alpha,
    double beta,
    double k) {
  auto dim = input.dim();
  TORCH_CHECK(
      dim >= 3,
      "Expected 3D or higher dimensionality input (got ",
      dim,
      " dimensions)");
  // Nothing to normalize. Returning early mirrors the Python implementation
  // and avoids failures below: `view` with a `-1` dimension is ambiguous for a
  // zero-element tensor, and pooling over an empty channel axis produces an
  // invalid output size.
  if (input.numel() == 0) {
    return input;
  }
  // Squared activations with a singleton dimension inserted at position 1,
  // so the channel axis becomes a spatial axis for the pooling ops below.
  auto div = input.mul(input).unsqueeze(1);
  if (dim == 3) {
    // Shape (N, 1, C, L): zero-pad only the channel axis (second-to-last),
    // then average a (size x 1) window with stride 1. Pad + kernel sizes
    // cancel, so the channel count is unchanged.
    div = detail::pad(
        div,
        /*pad=*/{0, 0, size / 2, (size - 1) / 2},
        /*mode=*/torch::kConstant,
        /*value=*/0);
    div = detail::avg_pool2d(
              div,
              /*kernel_size=*/{size, 1},
              /*stride=*/1,
              /*padding=*/0,
              /*ceil_mode=*/false,
              /*count_include_pad=*/true,
              /*divisor_override=*/std::nullopt)
              .squeeze(1);
  } else {
    // Fold every trailing spatial dimension after dim 2 into one, giving
    // (N, 1, C, D2, rest), so a single avg_pool3d handles any rank >= 4.
    auto sizes = input.sizes();
    div = div.view({sizes[0], 1, sizes[1], sizes[2], -1});
    // Zero-pad only the channel axis (third-from-last) before pooling.
    div = detail::pad(
        div,
        /*pad=*/{0, 0, 0, 0, size / 2, (size - 1) / 2},
        /*mode=*/torch::kConstant,
        /*value=*/0);
    div = detail::avg_pool3d(
              div,
              /*kernel_size=*/{size, 1, 1},
              /*stride=*/1,
              /*padding=*/0,
              /*ceil_mode=*/false,
              /*count_include_pad=*/true,
              /*divisor_override=*/std::nullopt)
              .squeeze(1);
    // Restore the caller's original shape.
    div = div.view(sizes);
  }
  // Denominator: (k + alpha * window_mean)^beta.
  div = div.mul(alpha).add(k).pow(beta);
  return input / div;
}
142 } // namespace detail
143 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
144
145 /// See
146 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.local_response_norm
147 /// about the exact behavior of this functional.
148 ///
149 /// See the documentation for
150 /// `torch::nn::functional::LocalResponseNormFuncOptions` class to learn what
151 /// optional arguments are supported for this functional.
152 ///
153 /// Example:
154 /// ```
155 /// namespace F = torch::nn::functional;
156 /// F::local_response_norm(x, F::LocalResponseNormFuncOptions(2));
157 /// ```
inline Tensor local_response_norm(
    const Tensor& input,
    const LocalResponseNormFuncOptions& options) {
  // Unpack the window size and the (alpha, beta, k) coefficients, then
  // delegate to the implementation in `detail`.
  const auto size = options.size();
  const auto alpha = options.alpha();
  const auto beta = options.beta();
  const auto k = options.k();
  return detail::local_response_norm(input, size, alpha, beta, k);
}
164
165 // ============================================================================
166
167 #ifndef DOXYGEN_SHOULD_SKIP_THIS
168 namespace detail {
inline Tensor group_norm(
    const Tensor& input,
    int64_t num_groups,
    const Tensor& weight,
    const Tensor& bias,
    double eps) {
  // Honour the user's global cuDNN switch when dispatching the ATen op.
  const bool cudnn_enabled = at::globalContext().userEnabledCuDNN();
  return torch::group_norm(
      input, num_groups, weight, bias, eps, cudnn_enabled);
}
183 } // namespace detail
184 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
185
186 /// See
187 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.group_norm
188 /// about the exact behavior of this functional.
189 ///
190 /// See the documentation for `torch::nn::functional::GroupNormFuncOptions`
191 /// class to learn what optional arguments are supported for this functional.
192 ///
193 /// Example:
194 /// ```
195 /// namespace F = torch::nn::functional;
196 /// F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5));
197 /// ```
inline Tensor group_norm(
    const Tensor& input,
    const GroupNormFuncOptions& options) {
  // Bind the affine parameters by reference and forward all options.
  const auto& weight = options.weight();
  const auto& bias = options.bias();
  return detail::group_norm(
      input, options.num_groups(), weight, bias, options.eps());
}
208
209 } // namespace functional
210 } // namespace nn
211 } // namespace torch
212