1 #include <c10/util/irange.h>
2 #include <torch/script.h>
3 #include <torch/cuda.h>
4
5 #include "op.h"
6
7 #include <memory>
8 #include <string>
9 #include <vector>
10
11 #include <iostream>
12
13 namespace helpers {
14 template <typename Predicate>
check_all_parameters(const torch::jit::Module & module,Predicate predicate)15 void check_all_parameters(
16 const torch::jit::Module& module,
17 Predicate predicate) {
18 for (at::Tensor parameter : module.parameters()) {
19 AT_ASSERT(predicate(parameter));
20 }
21 }
22
23 template<class Result, class... Args>
get_operator_from_registry_and_execute(const char * op_name,Args &&...args)24 Result get_operator_from_registry_and_execute(const char* op_name, Args&&... args) {
25 auto& ops = torch::jit::getAllOperatorsFor(
26 torch::jit::Symbol::fromQualString(op_name));
27 TORCH_INTERNAL_ASSERT(ops.size() == 1);
28
29 auto& op = ops.front();
30 TORCH_INTERNAL_ASSERT(op->schema().name() == op_name);
31
32 torch::jit::Stack stack;
33 torch::jit::push(stack, std::forward<Args>(args)...);
34 op->getOperation()(stack);
35
36 TORCH_INTERNAL_ASSERT(1 == stack.size());
37 return torch::jit::pop(stack).to<Result>();
38 }
39 } // namespace helpers
40
get_operator_from_registry_and_execute()41 void get_operator_from_registry_and_execute() {
42 std::vector<torch::Tensor> output =
43 helpers::get_operator_from_registry_and_execute<std::vector<torch::Tensor>>("custom::op", torch::ones(5), 2.0, 3);
44
45 const auto manual = custom_op(torch::ones(5), 2.0, 3);
46
47 TORCH_INTERNAL_ASSERT(output.size() == 3);
48 for (const auto i : c10::irange(output.size())) {
49 TORCH_INTERNAL_ASSERT(output[i].allclose(torch::ones(5) * 2));
50 TORCH_INTERNAL_ASSERT(output[i].allclose(manual[i]));
51 }
52 }
53
get_autograd_operator_from_registry_and_execute()54 void get_autograd_operator_from_registry_and_execute() {
55 torch::Tensor x = torch::randn({5,5}, torch::requires_grad());
56 torch::Tensor y = torch::randn({5,5}, torch::requires_grad());
57 torch::Tensor z = torch::randn({5,5}, torch::requires_grad());
58
59 torch::Tensor output =
60 helpers::get_operator_from_registry_and_execute<torch::Tensor>("custom::op_with_autograd", x, 2, y, std::optional<torch::Tensor>());
61
62 TORCH_INTERNAL_ASSERT(output.allclose(x + 2*y + x*y));
63 auto go = torch::ones({}, torch::requires_grad());
64 output.sum().backward(go, false, true);
65
66 TORCH_INTERNAL_ASSERT(torch::allclose(x.grad(), y + torch::ones({5,5})));
67 TORCH_INTERNAL_ASSERT(torch::allclose(y.grad(), x + torch::ones({5,5})*2));
68
69 // Test with optional argument.
70 at::zero_(x.mutable_grad());
71 at::zero_(y.mutable_grad());
72 output = helpers::get_operator_from_registry_and_execute<torch::Tensor>(
73 "custom::op_with_autograd", x, 2, y, z);
74
75 TORCH_INTERNAL_ASSERT(output.allclose(x + 2*y + x*y + z));
76 go = torch::ones({}, torch::requires_grad());
77 output.sum().backward(go, false, true);
78
79 TORCH_INTERNAL_ASSERT(torch::allclose(x.grad(), y + torch::ones({5,5})));
80 TORCH_INTERNAL_ASSERT(torch::allclose(y.grad(), x + torch::ones({5,5})*2));
81 TORCH_INTERNAL_ASSERT(torch::allclose(z.grad(), torch::ones({5,5})));
82 }
83
get_autograd_operator_from_registry_and_execute_in_nograd_mode()84 void get_autograd_operator_from_registry_and_execute_in_nograd_mode() {
85 at::AutoDispatchBelowAutograd guard;
86
87 torch::Tensor x = torch::randn({5,5}, torch::requires_grad());
88 torch::Tensor y = torch::randn({5,5}, torch::requires_grad());
89
90 torch::Tensor output =
91 helpers::get_operator_from_registry_and_execute<torch::Tensor>("custom::op_with_autograd", x, 2, y, std::optional<torch::Tensor>());
92
93 TORCH_INTERNAL_ASSERT(output.allclose(x + 2*y + x*y));
94 }
95
load_serialized_module_with_custom_op_and_execute(const std::string & path_to_exported_script_module)96 void load_serialized_module_with_custom_op_and_execute(
97 const std::string& path_to_exported_script_module) {
98 torch::jit::Module module =
99 torch::jit::load(path_to_exported_script_module);
100 std::vector<torch::jit::IValue> inputs;
101 inputs.push_back(torch::ones(5));
102 auto output = module.forward(inputs).toTensor();
103
104 AT_ASSERT(output.allclose(torch::ones(5) + 1));
105 }
106
test_argument_checking_for_serialized_modules(const std::string & path_to_exported_script_module)107 void test_argument_checking_for_serialized_modules(
108 const std::string& path_to_exported_script_module) {
109 torch::jit::Module module =
110 torch::jit::load(path_to_exported_script_module);
111
112 try {
113 module.forward({torch::jit::IValue(1), torch::jit::IValue(2)});
114 AT_ASSERT(false);
115 } catch (const c10::Error& error) {
116 AT_ASSERT(
117 std::string(error.what_without_backtrace())
118 .find("Expected at most 2 argument(s) for operator 'forward', "
119 "but received 3 argument(s)") == 0);
120 }
121
122 try {
123 module.forward({torch::jit::IValue(5)});
124 AT_ASSERT(false);
125 } catch (const c10::Error& error) {
126 AT_ASSERT(
127 std::string(error.what_without_backtrace())
128 .find("forward() Expected a value of type 'Tensor' "
129 "for argument 'input' but instead found type 'int'") == 0);
130 }
131
132 try {
133 module.forward({});
134 AT_ASSERT(false);
135 } catch (const c10::Error& error) {
136 AT_ASSERT(
137 std::string(error.what_without_backtrace())
138 .find("forward() is missing value for argument 'input'") == 0);
139 }
140 }
141
test_move_to_device(const std::string & path_to_exported_script_module)142 void test_move_to_device(const std::string& path_to_exported_script_module) {
143 torch::jit::Module module =
144 torch::jit::load(path_to_exported_script_module);
145
146 helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
147 return tensor.device().is_cpu();
148 });
149
150 module.to(torch::kCUDA);
151
152 helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
153 return tensor.device().is_cuda();
154 });
155
156 module.to(torch::kCPU);
157
158 helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
159 return tensor.device().is_cpu();
160 });
161 }
162
test_move_to_dtype(const std::string & path_to_exported_script_module)163 void test_move_to_dtype(const std::string& path_to_exported_script_module) {
164 torch::jit::Module module =
165 torch::jit::load(path_to_exported_script_module);
166
167 module.to(torch::kFloat16);
168
169 helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
170 return tensor.dtype() == torch::kFloat16;
171 });
172
173 module.to(torch::kDouble);
174
175 helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
176 return tensor.dtype() == torch::kDouble;
177 });
178 }
179
main(int argc,const char * argv[])180 int main(int argc, const char* argv[]) {
181 if (argc != 2) {
182 std::cerr << "usage: test_custom_ops <path-to-exported-script-module>\n";
183 return -1;
184 }
185 const std::string path_to_exported_script_module = argv[1];
186
187 get_operator_from_registry_and_execute();
188 get_autograd_operator_from_registry_and_execute();
189 get_autograd_operator_from_registry_and_execute_in_nograd_mode();
190 load_serialized_module_with_custom_op_and_execute(
191 path_to_exported_script_module);
192 test_argument_checking_for_serialized_modules(path_to_exported_script_module);
193 test_move_to_dtype(path_to_exported_script_module);
194
195 if (torch::cuda::device_count() > 0) {
196 test_move_to_device(path_to_exported_script_module);
197 }
198
199 std::cout << "ok\n";
200 }
201