1 #include <ATen/native/vulkan/ops/Common.h> 2 3 #include <torch/library.h> 4 5 namespace at { 6 namespace native { 7 namespace vulkan { 8 namespace ops { 9 namespace { 10 11 using namespace api::utils; 12 _local_scalar_dense(const Tensor & self)13Scalar _local_scalar_dense(const Tensor& self) { 14 TORCH_CHECK( 15 self.dtype() == ScalarType::Float, "Only float dtype is supported"); 16 return Scalar(self.cpu().item<float>()); 17 } 18 19 #ifdef USE_VULKAN_API 20 TORCH_LIBRARY_IMPL(aten,Vulkan,m)21TORCH_LIBRARY_IMPL(aten, Vulkan, m) { 22 m.impl( 23 TORCH_SELECTIVE_NAME("aten::_local_scalar_dense"), 24 TORCH_FN(_local_scalar_dense)); 25 } 26 27 #endif /* USE_VULKAN_API */ 28 29 } // namespace 30 } // namespace ops 31 } // namespace vulkan 32 } // namespace native 33 } // namespace at 34