• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 #include <torch/torch.h>
3 #include <limits>
4 #include <sstream>
5 
// Returns true if `str` ends with `suffix`.
// An empty suffix matches every string (including the empty string).
bool ends_with(const std::string& str, const std::string& suffix) {
  // Compare the tail in place instead of materializing it with substr(),
  // which would allocate a temporary string on every call.
  return str.size() >= suffix.size() &&
      str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Streaming a random 3x3 float tensor allocated on MPS must end with the
// "[ MPSFloatType{3,3} ]" type/shape trailer.
TEST(MPSPrintTest, PrintFloatMatrix) {
  const auto tensor = torch::randn({3, 3}, at::device(at::kMPS));
  std::stringstream out;
  out << tensor;
  const std::string printed = out.str();
  ASSERT_TRUE(ends_with(printed, "[ MPSFloatType{3,3} ]")) << " got " << printed;
}

// Streaming a random 2x2x2x2 half-precision tensor allocated on MPS must end
// with the "[ MPSHalfType{2,2,2,2} ]" type/shape trailer.
TEST(MPSPrintTest, PrintHalf4DTensor) {
  const auto options = at::device(at::kMPS).dtype(at::kHalf);
  const auto tensor = torch::randn({2, 2, 2, 2}, options);
  std::stringstream out;
  out << tensor;
  const std::string printed = out.str();
  ASSERT_TRUE(ends_with(printed, "[ MPSHalfType{2,2,2,2} ]")) << " got " << printed;
}

// Streaming a 2x2 MPS tensor filled with INT_MAX must end with the
// "[ MPSLongType{2,2} ]" trailer. No dtype is passed explicitly; the test
// expects the integral fill value to yield a Long tensor.
TEST(MPSPrintTest, PrintLongMatrix) {
  constexpr int kFill = std::numeric_limits<int>::max();
  const auto tensor = torch::full({2, 2}, kFill, at::device(at::kMPS));
  std::stringstream out;
  out << tensor;
  const std::string printed = out.str();
  ASSERT_TRUE(ends_with(printed, "[ MPSLongType{2,2} ]")) << " got " << printed;
}

// A 0-dim float tensor of ones on MPS must print exactly as its value,
// a newline, and the "[ MPSFloatType{} ]" trailer (empty sizes).
TEST(MPSPrintTest, PrintFloatScalar) {
  std::stringstream ss;
  ss << torch::ones({}, at::device(at::kMPS));
  // ASSERT_EQ reports both the expected and the actual string on failure,
  // whereas ASSERT_TRUE(a == b) only reports that the comparison was false.
  ASSERT_EQ(ss.str(), "1\n[ MPSFloatType{} ]");
}