• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <ATen/TensorIterator.h>
4 #include <ATen/native/DispatchStub.h>
5 
6 namespace at::native {
7 
8 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
9 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
10 
11 enum class BatchNormBackend {
12   Native,
13   Cudnn,
14   Miopen,
15 };
16 
17 TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
18 
19 }  // namespace at::native
20