• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/native/vulkan/ops/Common.h>
2 
3 #include <torch/library.h>
4 
5 namespace at {
6 namespace native {
7 namespace vulkan {
8 namespace ops {
9 namespace {
10 
11 using namespace api::utils;
12 
_local_scalar_dense(const Tensor & self)13 Scalar _local_scalar_dense(const Tensor& self) {
14   TORCH_CHECK(
15       self.dtype() == ScalarType::Float, "Only float dtype is supported");
16   return Scalar(self.cpu().item<float>());
17 }
18 
19 #ifdef USE_VULKAN_API
20 
TORCH_LIBRARY_IMPL(aten,Vulkan,m)21 TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
22   m.impl(
23       TORCH_SELECTIVE_NAME("aten::_local_scalar_dense"),
24       TORCH_FN(_local_scalar_dense));
25 }
26 
27 #endif /* USE_VULKAN_API */
28 
29 } // namespace
30 } // namespace ops
31 } // namespace vulkan
32 } // namespace native
33 } // namespace at
34