• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/ATen.h>
2 #include <ATen/native/TensorIterator.h>
3 #include <ATen/native/cuda/Loops.cuh>
4 
namespace at {
namespace native {

// In-place ReLU for a quantized CUDA tensor.
//
// Works directly on the stored integer representation. Every raw value
// smaller than the tensor's zero point is raised to the zero point, and all
// other values are left untouched. The scale is never read.
// Under PyTorch's affine quantization (real = scale * (q - zero_point)), the
// zero point is the code for real 0.0. That makes this clamp a ReLU without
// a dequantize/requantize round trip.
//
// The function dispatches over the quantized dtypes covered by
// AT_DISPATCH_QINT_TYPES. It launches one elementwise kernel through
// TensorIterator, with `self` serving as both input and output.
//
// NOTE(review): q_zero_point() yields a single zero point, so this presumably
// supports only per-tensor quantization. Confirm how per-channel inputs are
// rejected.
//
// Returns `self`, modified in place.
Tensor& relu_quantized_cuda_(Tensor& self) {
  const auto zero_point = self.q_zero_point();
  AT_DISPATCH_QINT_TYPES(self.scalar_type(), "qrelu_cuda", [&]() {
    // Narrow the zero point to the raw storage type once, on the host. The
    // original code did this per element inside the kernel lambda; the
    // conversion itself is identical.
    const underlying_t floor_value = static_cast<underlying_t>(zero_point);
    auto in_place_iter = TensorIterator::unary_op(self, self);
    gpu_kernel(in_place_iter,
               [floor_value] GPU_LAMBDA(scalar_t q) -> scalar_t {
                 return scalar_t(std::max<underlying_t>(q.val_, floor_value));
               });
  });
  return self;
}

}  // namespace native
}  // namespace at
22