1 #include <gtest/gtest.h>
2 #include <torch/torch.h>
3 #include <limits>
4 #include <sstream>
5
// Returns true when `str` ends with `suffix`.
//
// An empty `suffix` matches every string (including the empty one); a
// `suffix` longer than `str` never matches.
bool ends_with(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;
  }
  // Compare the tail of `str` in place so no temporary substring is allocated.
  return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
11
// Streams a random 3x3 tensor created on the MPS device and checks that the
// printed text finishes with the `[ MPSFloatType{3,3} ]` footer.
TEST(MPSPrintTest, PrintFloatMatrix) {
  const auto options = at::device(at::kMPS);
  std::stringstream out;
  out << torch::randn({3, 3}, options);
  const std::string printed = out.str();
  ASSERT_TRUE(ends_with(printed, "[ MPSFloatType{3,3} ]")) << " got " << printed;
}
17
// Streams a random 2x2x2x2 half-precision tensor created on the MPS device
// and checks that the printed text finishes with the
// `[ MPSHalfType{2,2,2,2} ]` footer.
TEST(MPSPrintTest, PrintHalf4DTensor) {
  const auto options = at::device(at::kMPS).dtype(at::kHalf);
  std::stringstream out;
  out << torch::randn({2, 2, 2, 2}, options);
  const std::string printed = out.str();
  ASSERT_TRUE(ends_with(printed, "[ MPSHalfType{2,2,2,2} ]")) << " got " << printed;
}
23
// Streams a 2x2 MPS tensor filled with the largest `int` value and checks
// that the printed text finishes with the `[ MPSLongType{2,2} ]` footer.
TEST(MPSPrintTest, PrintLongMatrix) {
  const int fill_value = std::numeric_limits<int>::max();
  const auto options = at::device(at::kMPS);
  std::stringstream out;
  out << torch::full({2, 2}, fill_value, options);
  const std::string printed = out.str();
  ASSERT_TRUE(ends_with(printed, "[ MPSLongType{2,2} ]")) << " got " << printed;
}
29
// Streams a zero-dimensional MPS tensor holding 1 and requires the printed
// text to be exactly the value line followed by the `[ MPSFloatType{} ]`
// footer.
TEST(MPSPrintTest, PrintFloatScalar) {
  std::stringstream ss;
  ss << torch::ones({}, at::device(at::kMPS));
  // ASSERT_EQ reports both the expected and the actual text on failure, so no
  // hand-written " got " message is needed.
  ASSERT_EQ(ss.str(), "1\n[ MPSFloatType{} ]");
}
35