1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS 2 #include <ATen/core/Tensor.h> 3 #include <ATen/NamedTensorUtils.h> 4 5 #ifndef AT_PER_OPERATOR_HEADERS 6 #include <ATen/NativeFunctions.h> 7 #include <ATen/MPSFunctions.h> 8 #else 9 #include <ATen/ops/eq_mps_dispatch.h> 10 #include <ATen/ops/equal_native.h> 11 #endif 12 13 namespace at { 14 namespace mps { 15 TORCH_API at::Tensor eq(const at::Tensor & self, const at::Tensor & other); 16 } // namespace 17 namespace native { 18 mps_equal(const Tensor & self,const Tensor & src)19bool mps_equal(const Tensor& self, const Tensor &src) { 20 if (!at::namedinference::are_names_equal( 21 self.unsafeGetTensorImpl(), src.unsafeGetTensorImpl())) { 22 return false; 23 } 24 at::NoNamesGuard guard; 25 TORCH_CHECK(self.device() == src.device(), "Cannot compare two tensors on " 26 "different devices. Got: ", self.device(), " and ", src.device()); 27 if (self.sizes() != src.sizes()) { 28 return false; 29 } 30 if (self.numel() == 0) { 31 return true; 32 } 33 return at::mps::eq(self, src).all().item().to<bool>(); 34 } 35 36 } // namespace native 37 } // namespace at 38