#pragma once

#include <torch/nn/functional/padding.h>
#include <torch/nn/functional/pooling.h>
#include <torch/nn/options/normalization.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor normalize(
    const Tensor& input,
    double p,
    int64_t dim,
    double eps,
    std::optional<Tensor> out) {
  if (out == std::nullopt) {
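    // Lp norm along `dim` (kept as a singleton dimension), clamped from below
    // by `eps` so all-zero slices do not divide by zero, then broadcast back
    // to the input shape.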
    auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input);
    return input / denom;
  } else {
    auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input);
    return torch::div_out(*out, input, denom);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.normalize
/// for the exact behavior of this functional.
///
/// See the documentation for the `torch::nn::functional::NormalizeFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
/// ```
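///
/// A fuller sketch, reusing the `F` alias above (the values are illustrative
/// assumptions, not prescribed by the API):
/// ```
/// auto x = torch::tensor({{3.0, 4.0}});
/// // Divide each row by its L2 norm; y is {{0.6, 0.8}}.
/// auto y = F::normalize(x, F::NormalizeFuncOptions().p(2).dim(1));
/// ```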
inline Tensor normalize(
    const Tensor& input,
    NormalizeFuncOptions options = {}) {
  return detail::normalize(
      input, options.p(), options.dim(), options.eps(), options.out());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor layer_norm(
    const Tensor& input,
    const std::vector<int64_t>& normalized_shape,
    const Tensor& weight,
    const Tensor& bias,
    double eps) {
  return torch::layer_norm(input, normalized_shape, weight, bias, eps);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.layer_norm
/// for the exact behavior of this functional.
///
/// See the documentation for the `torch::nn::functional::LayerNormFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5));
/// ```
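///
/// A fuller sketch, reusing the `F` alias above (the shapes are illustrative
/// assumptions, not prescribed by the API):
/// ```
/// // Normalize over the trailing 2x2 slice of each sample; every such
/// // slice of y then has zero mean and unit variance.
/// auto x = torch::randn({4, 2, 2});
/// auto y = F::layer_norm(x, F::LayerNormFuncOptions({2, 2}));
/// ```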
inline Tensor layer_norm(
    const Tensor& input,
    const LayerNormFuncOptions& options) {
  return detail::layer_norm(
      input,
      options.normalized_shape(),
      options.weight(),
      options.bias(),
      options.eps());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor local_response_norm(
    const Tensor& input,
    int64_t size,
    double alpha,
    double beta,
    double k) {
  auto dim = input.dim();
  TORCH_CHECK(
      dim >= 3,
      "Expected 3D or higher dimensionality input (got ",
      dim,
      " dimensions)");
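  // Square the input and insert a dummy dimension in front of the channel
  // axis, so the per-window sum of squares can be computed with average
  // pooling over what is now a spatial axis.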
  auto div = input.mul(input).unsqueeze(1);
  if (dim == 3) {
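    // 3D input: pad the channel axis so every position sees a full window,
    // then average the squares over `size` neighboring channels with a
    // (size, 1) kernel.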
    div = detail::pad(
        div,
        /*pad=*/{0, 0, size / 2, (size - 1) / 2},
        /*mode=*/torch::kConstant,
        /*value=*/0);
    div = detail::avg_pool2d(
              div,
              /*kernel_size=*/{size, 1},
              /*stride=*/1,
              /*padding=*/0,
              /*ceil_mode=*/false,
              /*count_include_pad=*/true,
              /*divisor_override=*/std::nullopt)
              .squeeze(1);
  } else {
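    // 4D and higher: reshape to 5D, flattening all spatial dimensions after
    // the first, so a single avg_pool3d pass can average over the channel
    // window.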
    auto sizes = input.sizes();
    div = div.view({sizes[0], 1, sizes[1], sizes[2], -1});
    div = detail::pad(
        div,
        /*pad=*/{0, 0, 0, 0, size / 2, (size - 1) / 2},
        /*mode=*/torch::kConstant,
        /*value=*/0);
    div = detail::avg_pool3d(
              div,
              /*kernel_size=*/{size, 1, 1},
              /*stride=*/1,
              /*padding=*/0,
              /*ceil_mode=*/false,
              /*count_include_pad=*/true,
              /*divisor_override=*/std::nullopt)
              .squeeze(1);
    div = div.view(sizes);
  }
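  // Denominator of local response normalization:
  // (k + alpha * mean(x^2 over the window))^beta.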
  div = div.mul(alpha).add(k).pow(beta);
  return input / div;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.local_response_norm
/// for the exact behavior of this functional.
///
/// See the documentation for the
/// `torch::nn::functional::LocalResponseNormFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::local_response_norm(x, F::LocalResponseNormFuncOptions(2));
/// ```
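///
/// A fuller sketch, reusing the `F` alias above (the shape is an illustrative
/// assumption, not prescribed by the API):
/// ```
/// // Each activation is divided by a term computed from the squared
/// // activations of a window of 2 neighboring channels.
/// auto x = torch::randn({1, 4, 8, 8}); // N, C, H, W
/// auto y = F::local_response_norm(x, F::LocalResponseNormFuncOptions(2));
/// ```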
inline Tensor local_response_norm(
    const Tensor& input,
    const LocalResponseNormFuncOptions& options) {
  return detail::local_response_norm(
      input, options.size(), options.alpha(), options.beta(), options.k());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor group_norm(
    const Tensor& input,
    int64_t num_groups,
    const Tensor& weight,
    const Tensor& bias,
    double eps) {
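  // Forward to the ATen kernel; the trailing flag enables the cuDNN path
  // only when cuDNN is enabled in the global context.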
  return torch::group_norm(
      input,
      num_groups,
      weight,
      bias,
      eps,
      at::globalContext().userEnabledCuDNN());
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.group_norm
/// for the exact behavior of this functional.
///
/// See the documentation for the `torch::nn::functional::GroupNormFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5));
/// ```
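///
/// A fuller sketch, reusing the `F` alias above (the shape is an illustrative
/// assumption; the channel count must be divisible by the group count):
/// ```
/// // 6 channels split into 2 groups of 3; mean and variance are computed
/// // per sample and per group.
/// auto x = torch::randn({4, 6, 8});
/// auto y = F::group_norm(x, F::GroupNormFuncOptions(2));
/// ```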
inline Tensor group_norm(
    const Tensor& input,
    const GroupNormFuncOptions& options) {
  return detail::group_norm(
      input,
      options.num_groups(),
      options.weight(),
      options.bias(),
      options.eps());
}

} // namespace functional
} // namespace nn
} // namespace torch