• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <test/cpp/common/support.h>
4 
5 #include <gtest/gtest.h>
6 
7 #include <ATen/TensorIndexing.h>
8 #include <c10/util/Exception.h>
9 #include <torch/nn/cloneable.h>
10 #include <torch/types.h>
11 #include <torch/utils.h>
12 
13 #include <string>
14 #include <utility>
15 
16 namespace torch {
17 namespace test {
18 
19 // Lets you use a container without making a new class,
20 // for experimental implementations
21 class SimpleContainer : public nn::Cloneable<SimpleContainer> {
22  public:
reset()23   void reset() override {}
24 
25   template <typename ModuleHolder>
26   ModuleHolder add(
27       ModuleHolder module_holder,
28       std::string name = std::string()) {
29     return Module::register_module(std::move(name), module_holder);
30   }
31 };
32 
33 struct SeedingFixture : public ::testing::Test {
SeedingFixtureSeedingFixture34   SeedingFixture() {
35     torch::manual_seed(0);
36   }
37 };
38 
39 struct WarningCapture : public WarningHandler {
WarningCaptureWarningCapture40   WarningCapture() : prev_(WarningUtils::get_warning_handler()) {
41     WarningUtils::set_warning_handler(this);
42   }
43 
~WarningCaptureWarningCapture44   ~WarningCapture() override {
45     WarningUtils::set_warning_handler(prev_);
46   }
47 
messagesWarningCapture48   const std::vector<std::string>& messages() {
49     return messages_;
50   }
51 
strWarningCapture52   std::string str() {
53     return c10::Join("\n", messages_);
54   }
55 
processWarningCapture56   void process(const c10::Warning& warning) override {
57     messages_.push_back(warning.msg());
58   }
59 
60  private:
61   WarningHandler* prev_;
62   std::vector<std::string> messages_;
63 };
64 
pointer_equal(at::Tensor first,at::Tensor second)65 inline bool pointer_equal(at::Tensor first, at::Tensor second) {
66   return first.data_ptr() == second.data_ptr();
67 }
68 
69 // This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
70 // torch.Tensor)` branch in `TestCase.assertEqual` in
71 // torch/testing/_internal/common_utils.py
72 inline void assert_tensor_equal(
73     at::Tensor a,
74     at::Tensor b,
75     bool allow_inf = false) {
76   ASSERT_TRUE(a.sizes() == b.sizes());
77   if (a.numel() > 0) {
78     if (a.device().type() == torch::kCPU &&
79         (a.scalar_type() == torch::kFloat16 ||
80          a.scalar_type() == torch::kBFloat16)) {
81       // CPU half and bfloat16 tensors don't have the methods we need below
82       a = a.to(torch::kFloat32);
83     }
84     if (a.device().type() == torch::kCUDA &&
85         a.scalar_type() == torch::kBFloat16) {
86       // CUDA bfloat16 tensors don't have the methods we need below
87       a = a.to(torch::kFloat32);
88     }
89     b = b.to(a);
90 
91     if ((a.scalar_type() == torch::kBool) !=
92         (b.scalar_type() == torch::kBool)) {
93       TORCH_CHECK(false, "Was expecting both tensors to be bool type.");
94     } else {
95       if (a.scalar_type() == torch::kBool && b.scalar_type() == torch::kBool) {
96         // we want to respect precision but as bool doesn't support subtraction,
97         // boolean tensor has to be converted to int
98         a = a.to(torch::kInt);
99         b = b.to(torch::kInt);
100       }
101 
102       auto diff = a - b;
103       if (a.is_floating_point()) {
104         // check that NaNs are in the same locations
105         auto nan_mask = torch::isnan(a);
106         ASSERT_TRUE(torch::equal(nan_mask, torch::isnan(b)));
107         diff.index_put_({nan_mask}, 0);
108         // inf check if allow_inf=true
109         if (allow_inf) {
110           auto inf_mask = torch::isinf(a);
111           auto inf_sign = inf_mask.sign();
112           ASSERT_TRUE(torch::equal(inf_sign, torch::isinf(b).sign()));
113           diff.index_put_({inf_mask}, 0);
114         }
115       }
116       // TODO: implement abs on CharTensor (int8)
117       if (diff.is_signed() && diff.scalar_type() != torch::kInt8) {
118         diff = diff.abs();
119       }
120       auto max_err = diff.max().item<double>();
121       ASSERT_LE(max_err, 1e-5);
122     }
123   }
124 }
125 
126 // This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
127 // torch.Tensor)` branch in `TestCase.assertNotEqual` in
128 // torch/testing/_internal/common_utils.py
assert_tensor_not_equal(at::Tensor x,at::Tensor y)129 inline void assert_tensor_not_equal(at::Tensor x, at::Tensor y) {
130   if (x.sizes() != y.sizes()) {
131     return;
132   }
133   ASSERT_GT(x.numel(), 0);
134   y = y.type_as(x);
135   y = x.is_cuda() ? y.to({torch::kCUDA, x.get_device()}) : y.cpu();
136   auto nan_mask = x != x;
137   if (torch::equal(nan_mask, y != y)) {
138     auto diff = x - y;
139     if (diff.is_signed()) {
140       diff = diff.abs();
141     }
142     diff.index_put_({nan_mask}, 0);
143     // Use `item()` to work around:
144     // https://github.com/pytorch/pytorch/issues/22301
145     auto max_err = diff.max().item<double>();
146     ASSERT_GE(max_err, 1e-5);
147   }
148 }
149 
// Counts non-overlapping occurrences of `substr` in `str`, scanning left to
// right (so "aaaa" contains "aa" twice). Returns 0 when `substr` is empty.
inline int count_substr_occurrences(
    const std::string& str,
    const std::string& substr) {
  // find("") matches at every position without advancing the cursor, so an
  // empty needle would otherwise loop forever.
  if (substr.empty()) {
    return 0;
  }
  int count = 0;
  for (auto pos = str.find(substr); pos != std::string::npos;
       pos = str.find(substr, pos + substr.size())) {
    ++count;
  }
  return count;
}
163 
164 // A RAII, thread local (!) guard that changes default dtype upon
165 // construction, and sets it back to the original dtype upon destruction.
166 //
167 // Usage of this guard is synchronized across threads, so that at any given
168 // time, only one guard can take effect.
169 struct AutoDefaultDtypeMode {
170   static std::mutex default_dtype_mutex;
171 
AutoDefaultDtypeModeAutoDefaultDtypeMode172   AutoDefaultDtypeMode(c10::ScalarType default_dtype)
173       : prev_default_dtype(
174             torch::typeMetaToScalarType(torch::get_default_dtype())) {
175     default_dtype_mutex.lock();
176     torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
177   }
~AutoDefaultDtypeModeAutoDefaultDtypeMode178   ~AutoDefaultDtypeMode() {
179     default_dtype_mutex.unlock();
180     torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
181   }
182   c10::ScalarType prev_default_dtype;
183 };
184 
assert_tensor_creation_meta(torch::Tensor & x,torch::autograd::CreationMeta creation_meta)185 inline void assert_tensor_creation_meta(
186     torch::Tensor& x,
187     torch::autograd::CreationMeta creation_meta) {
188   auto autograd_meta = x.unsafeGetTensorImpl()->autograd_meta();
189   TORCH_CHECK(autograd_meta);
190   auto view_meta =
191       static_cast<torch::autograd::DifferentiableViewMeta*>(autograd_meta);
192   TORCH_CHECK(view_meta->has_bw_view());
193   ASSERT_EQ(view_meta->get_creation_meta(), creation_meta);
194 }
195 } // namespace test
196 } // namespace torch
197