• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1//  Copyright © 2022 Apple Inc.
2#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3#include <ATen/Dispatch.h>
4#include <ATen/native/mps/Copy.h>
5#include <ATen/native/mps/OperationUtils.h>
6#include <ATen/ops/_local_scalar_dense_native.h>
7
8#ifdef __OBJC__
9#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
10#endif
11
12using namespace at::mps;
13
14namespace at::native {
15
16Scalar _local_scalar_dense_mps(const Tensor& self) {
17  Scalar r;
18
19  auto output = at::empty_like(self, TensorOptions(kCPU));
20  mps::mps_copy_(output, self, false);
21  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half,
22                                         at::ScalarType::Bool,
23                                         at::ScalarType::BFloat16,
24                                         self.scalar_type(),
25                                         "_local_scalar_dense_mps",
26                                         [&] {
27                                           scalar_t value = *output.data_ptr<scalar_t>();
28                                           r = Scalar(value);
29                                         });
30
31  return r;
32}
33
34} // namespace at::native
35