// End-to-end tests for custom TorchScript operators (see op.h):
// registry lookup/execution, autograd, serialized-module loading,
// argument checking, and device/dtype movement.
#include "op.h"

#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include <c10/util/irange.h>
#include <torch/cuda.h>
#include <torch/script.h>
13 namespace helpers {
14 template <typename Predicate>
check_all_parameters(const torch::jit::Module & module,Predicate predicate)15 void check_all_parameters(
16     const torch::jit::Module& module,
17     Predicate predicate) {
18   for (at::Tensor parameter : module.parameters()) {
19     AT_ASSERT(predicate(parameter));
20   }
21 }
22 
23 template<class Result, class... Args>
get_operator_from_registry_and_execute(const char * op_name,Args &&...args)24 Result get_operator_from_registry_and_execute(const char* op_name, Args&&... args) {
25   auto& ops = torch::jit::getAllOperatorsFor(
26       torch::jit::Symbol::fromQualString(op_name));
27   TORCH_INTERNAL_ASSERT(ops.size() == 1);
28 
29   auto& op = ops.front();
30   TORCH_INTERNAL_ASSERT(op->schema().name() == op_name);
31 
32   torch::jit::Stack stack;
33   torch::jit::push(stack, std::forward<Args>(args)...);
34   op->getOperation()(stack);
35 
36   TORCH_INTERNAL_ASSERT(1 == stack.size());
37   return torch::jit::pop(stack).to<Result>();
38 }
39 } // namespace helpers
40 
get_operator_from_registry_and_execute()41 void get_operator_from_registry_and_execute() {
42   std::vector<torch::Tensor> output =
43     helpers::get_operator_from_registry_and_execute<std::vector<torch::Tensor>>("custom::op", torch::ones(5), 2.0, 3);
44 
45   const auto manual = custom_op(torch::ones(5), 2.0, 3);
46 
47   TORCH_INTERNAL_ASSERT(output.size() == 3);
48   for (const auto i : c10::irange(output.size())) {
49     TORCH_INTERNAL_ASSERT(output[i].allclose(torch::ones(5) * 2));
50     TORCH_INTERNAL_ASSERT(output[i].allclose(manual[i]));
51   }
52 }
53 
get_autograd_operator_from_registry_and_execute()54 void get_autograd_operator_from_registry_and_execute() {
55   torch::Tensor x = torch::randn({5,5}, torch::requires_grad());
56   torch::Tensor y = torch::randn({5,5}, torch::requires_grad());
57   torch::Tensor z = torch::randn({5,5}, torch::requires_grad());
58 
59   torch::Tensor output =
60     helpers::get_operator_from_registry_and_execute<torch::Tensor>("custom::op_with_autograd", x, 2, y, std::optional<torch::Tensor>());
61 
62   TORCH_INTERNAL_ASSERT(output.allclose(x + 2*y + x*y));
63   auto go = torch::ones({}, torch::requires_grad());
64   output.sum().backward(go, false, true);
65 
66   TORCH_INTERNAL_ASSERT(torch::allclose(x.grad(), y + torch::ones({5,5})));
67   TORCH_INTERNAL_ASSERT(torch::allclose(y.grad(), x + torch::ones({5,5})*2));
68 
69   // Test with optional argument.
70   at::zero_(x.mutable_grad());
71   at::zero_(y.mutable_grad());
72   output = helpers::get_operator_from_registry_and_execute<torch::Tensor>(
73       "custom::op_with_autograd", x, 2, y, z);
74 
75   TORCH_INTERNAL_ASSERT(output.allclose(x + 2*y + x*y + z));
76   go = torch::ones({}, torch::requires_grad());
77   output.sum().backward(go, false, true);
78 
79   TORCH_INTERNAL_ASSERT(torch::allclose(x.grad(), y + torch::ones({5,5})));
80   TORCH_INTERNAL_ASSERT(torch::allclose(y.grad(), x + torch::ones({5,5})*2));
81   TORCH_INTERNAL_ASSERT(torch::allclose(z.grad(), torch::ones({5,5})));
82 }
83 
get_autograd_operator_from_registry_and_execute_in_nograd_mode()84 void get_autograd_operator_from_registry_and_execute_in_nograd_mode() {
85   at::AutoDispatchBelowAutograd guard;
86 
87   torch::Tensor x = torch::randn({5,5}, torch::requires_grad());
88   torch::Tensor y = torch::randn({5,5}, torch::requires_grad());
89 
90   torch::Tensor output =
91     helpers::get_operator_from_registry_and_execute<torch::Tensor>("custom::op_with_autograd", x, 2, y, std::optional<torch::Tensor>());
92 
93   TORCH_INTERNAL_ASSERT(output.allclose(x + 2*y + x*y));
94 }
95 
load_serialized_module_with_custom_op_and_execute(const std::string & path_to_exported_script_module)96 void load_serialized_module_with_custom_op_and_execute(
97     const std::string& path_to_exported_script_module) {
98   torch::jit::Module module =
99       torch::jit::load(path_to_exported_script_module);
100   std::vector<torch::jit::IValue> inputs;
101   inputs.push_back(torch::ones(5));
102   auto output = module.forward(inputs).toTensor();
103 
104   AT_ASSERT(output.allclose(torch::ones(5) + 1));
105 }
106 
test_argument_checking_for_serialized_modules(const std::string & path_to_exported_script_module)107 void test_argument_checking_for_serialized_modules(
108     const std::string& path_to_exported_script_module) {
109   torch::jit::Module module =
110       torch::jit::load(path_to_exported_script_module);
111 
112   try {
113     module.forward({torch::jit::IValue(1), torch::jit::IValue(2)});
114     AT_ASSERT(false);
115   } catch (const c10::Error& error) {
116     AT_ASSERT(
117         std::string(error.what_without_backtrace())
118             .find("Expected at most 2 argument(s) for operator 'forward', "
119                   "but received 3 argument(s)") == 0);
120   }
121 
122   try {
123     module.forward({torch::jit::IValue(5)});
124     AT_ASSERT(false);
125   } catch (const c10::Error& error) {
126     AT_ASSERT(
127         std::string(error.what_without_backtrace())
128             .find("forward() Expected a value of type 'Tensor' "
129                   "for argument 'input' but instead found type 'int'") == 0);
130   }
131 
132   try {
133     module.forward({});
134     AT_ASSERT(false);
135   } catch (const c10::Error& error) {
136     AT_ASSERT(
137         std::string(error.what_without_backtrace())
138             .find("forward() is missing value for argument 'input'") == 0);
139   }
140 }
141 
test_move_to_device(const std::string & path_to_exported_script_module)142 void test_move_to_device(const std::string& path_to_exported_script_module) {
143   torch::jit::Module module =
144       torch::jit::load(path_to_exported_script_module);
145 
146   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
147     return tensor.device().is_cpu();
148   });
149 
150   module.to(torch::kCUDA);
151 
152   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
153     return tensor.device().is_cuda();
154   });
155 
156   module.to(torch::kCPU);
157 
158   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
159     return tensor.device().is_cpu();
160   });
161 }
162 
test_move_to_dtype(const std::string & path_to_exported_script_module)163 void test_move_to_dtype(const std::string& path_to_exported_script_module) {
164   torch::jit::Module module =
165       torch::jit::load(path_to_exported_script_module);
166 
167   module.to(torch::kFloat16);
168 
169   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
170     return tensor.dtype() == torch::kFloat16;
171   });
172 
173   module.to(torch::kDouble);
174 
175   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
176     return tensor.dtype() == torch::kDouble;
177   });
178 }
179 
main(int argc,const char * argv[])180 int main(int argc, const char* argv[]) {
181   if (argc != 2) {
182     std::cerr << "usage: test_custom_ops <path-to-exported-script-module>\n";
183     return -1;
184   }
185   const std::string path_to_exported_script_module = argv[1];
186 
187   get_operator_from_registry_and_execute();
188   get_autograd_operator_from_registry_and_execute();
189   get_autograd_operator_from_registry_and_execute_in_nograd_mode();
190   load_serialized_module_with_custom_op_and_execute(
191       path_to_exported_script_module);
192   test_argument_checking_for_serialized_modules(path_to_exported_script_module);
193   test_move_to_dtype(path_to_exported_script_module);
194 
195   if (torch::cuda::device_count() > 0) {
196     test_move_to_device(path_to_exported_script_module);
197   }
198 
199   std::cout << "ok\n";
200 }
201