• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/NamedTensorUtils.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/NativeFunctions.h>
7 #include <ATen/MPSFunctions.h>
8 #else
9 #include <ATen/ops/eq_mps_dispatch.h>
10 #include <ATen/ops/equal_native.h>
11 #endif
12 
13 namespace at {
14 namespace mps {
15 TORCH_API at::Tensor eq(const at::Tensor & self, const at::Tensor & other);
16 } // namespace
17 namespace native {
18 
mps_equal(const Tensor & self,const Tensor & src)19 bool mps_equal(const Tensor& self, const Tensor &src) {
20   if (!at::namedinference::are_names_equal(
21           self.unsafeGetTensorImpl(), src.unsafeGetTensorImpl())) {
22     return false;
23   }
24   at::NoNamesGuard guard;
25   TORCH_CHECK(self.device() == src.device(), "Cannot compare two tensors on "
26               "different devices. Got: ", self.device(), " and ", src.device());
27   if (self.sizes() != src.sizes()) {
28     return false;
29   }
30   if (self.numel() == 0) {
31     return true;
32   }
33   return at::mps::eq(self, src).all().item().to<bool>();
34 }
35 
36 } // namespace native
37 } // namespace at
38