1 #include <gtest/gtest.h>
2 #include <torch/csrc/jit/mobile/nnc/context.h>
3 #include <torch/csrc/jit/mobile/nnc/registry.h>
4 #include <ATen/Functions.h>
5
6 namespace torch {
7 namespace jit {
8 namespace mobile {
9 namespace nnc {
10
11 extern "C" {
12
13 // out = a * n (doing calculation in the `tmp` buffer)
slow_mul_kernel(void ** args)14 int slow_mul_kernel(void** args) {
15 const int size = 128;
16 at::Tensor a = at::from_blob(args[0], {size}, at::kFloat);
17 at::Tensor out = at::from_blob(args[1], {size}, at::kFloat);
18 at::Tensor n = at::from_blob(args[2], {1}, at::kInt);
19 at::Tensor tmp = at::from_blob(args[3], {size}, at::kFloat);
20
21 tmp.zero_();
22 for (int i = n.item().toInt(); i > 0; i--) {
23 tmp.add_(a);
24 }
25 out.copy_(tmp);
26 return 0;
27 }
28
// No-op kernel: used by tests that only exercise input-spec validation,
// where the kernel body itself is irrelevant.
int dummy_kernel(void** /* args */) {
  return 0;
}
32
33 } // extern "C"
34
// Register the kernels above so Function::run can look them up by the
// nnc_kernel_id set in each test ("slow_mul" / "dummy").
REGISTER_NNC_KERNEL("slow_mul", slow_mul_kernel)
REGISTER_NNC_KERNEL("dummy", dummy_kernel)
37
create_test_input_spec(const std::vector<int64_t> & sizes)38 InputSpec create_test_input_spec(const std::vector<int64_t>& sizes) {
39 InputSpec input_spec;
40 input_spec.sizes_ = sizes;
41 input_spec.dtype_ = at::kFloat;
42 return input_spec;
43 }
44
create_test_output_spec(const std::vector<int64_t> & sizes)45 OutputSpec create_test_output_spec(const std::vector<int64_t>& sizes) {
46 OutputSpec output_spec;
47 output_spec.sizes_ = sizes;
48 output_spec.dtype_ = at::kFloat;
49 return output_spec;
50 }
51
create_test_memory_plan(const std::vector<int64_t> & buffer_sizes)52 MemoryPlan create_test_memory_plan(const std::vector<int64_t>& buffer_sizes) {
53 MemoryPlan memory_plan;
54 memory_plan.buffer_sizes_ = buffer_sizes;
55 return memory_plan;
56 }
57
// End-to-end run of the registered "slow_mul" kernel: out = a * n.
TEST(Function, ExecuteSlowMul) {
  constexpr int kA = 999;
  constexpr int kN = 100;
  constexpr int kSize = 128;

  Function f;
  f.set_nnc_kernel_id("slow_mul");
  f.set_input_specs({create_test_input_spec({kSize})});
  f.set_output_specs({create_test_output_spec({kSize})});
  // The multiplier `n` is baked into the function as a parameter tensor.
  f.set_parameters(c10::impl::toList(c10::List<at::Tensor>({
      at::ones({1}, at::kInt).mul(kN)})));
  // One scratch buffer for the kernel's `tmp` tensor.
  f.set_memory_plan(create_test_memory_plan({sizeof(float) * kSize}));

  c10::List<at::Tensor> input({at::ones({kSize}, at::kFloat).mul(kA)});
  auto outputs = f.run(c10::impl::toList(input));
  auto actual = ((const c10::IValue&) outputs[0]).toTensor();
  auto expected = at::ones({kSize}, at::kFloat).mul(kA * kN);
  EXPECT_TRUE(actual.equal(expected));
}
80
// Round-trips a fully-populated Function through serialize()/the
// deserializing constructor and verifies every field survives.
TEST(Function, Serialization) {
  Function f;
  f.set_name("test_function");
  f.set_nnc_kernel_id("test_kernel");
  f.set_input_specs({create_test_input_spec({1, 3, 224, 224})});
  f.set_output_specs({create_test_output_spec({1000})});
  f.set_parameters(c10::impl::toList(c10::List<at::Tensor>({
      at::ones({1, 16, 3, 3}, at::kFloat),
      at::ones({16, 32, 1, 1}, at::kFloat),
      at::ones({32, 1, 3, 3}, at::kFloat)})));
  f.set_memory_plan(create_test_memory_plan({
      sizeof(float) * 1024,
      sizeof(float) * 2048,
  }));

  // Deserialize from the serialized bytes and compare field by field.
  auto serialized = f.serialize();
  Function roundtripped(serialized);

  EXPECT_EQ(roundtripped.name(), "test_function");
  EXPECT_EQ(roundtripped.nnc_kernel_id(), "test_kernel");

  EXPECT_EQ(roundtripped.input_specs().size(), 1);
  EXPECT_EQ(
      roundtripped.input_specs()[0].sizes_,
      std::vector<int64_t>({1, 3, 224, 224}));
  EXPECT_EQ(roundtripped.input_specs()[0].dtype_, at::kFloat);

  EXPECT_EQ(roundtripped.output_specs().size(), 1);
  EXPECT_EQ(roundtripped.output_specs()[0].sizes_, std::vector<int64_t>({1000}));
  EXPECT_EQ(roundtripped.output_specs()[0].dtype_, at::kFloat);

  EXPECT_EQ(roundtripped.parameters().size(), 3);
  EXPECT_EQ(
      roundtripped.parameters()[0].toTensor().sizes(),
      at::IntArrayRef({1, 16, 3, 3}));
  EXPECT_EQ(
      roundtripped.parameters()[1].toTensor().sizes(),
      at::IntArrayRef({16, 32, 1, 1}));
  EXPECT_EQ(
      roundtripped.parameters()[2].toTensor().sizes(),
      at::IntArrayRef({32, 1, 3, 3}));

  EXPECT_EQ(roundtripped.memory_plan().buffer_sizes_.size(), 2);
  EXPECT_EQ(roundtripped.memory_plan().buffer_sizes_[0], sizeof(float) * 1024);
  EXPECT_EQ(roundtripped.memory_plan().buffer_sizes_[1], sizeof(float) * 2048);
}
119
// An input matching the declared spec ({128} float) must be accepted.
TEST(Function, ValidInput) {
  constexpr int kSize = 128;

  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({kSize})});

  c10::List<at::Tensor> input({at::ones({kSize}, at::kFloat)});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_NO_THROW(f.run(c10::impl::toList(input)));
}
133
// An input whose size disagrees with the declared spec must be rejected
// with a c10::Error before the kernel runs.
TEST(Function, InvalidInput) {
  constexpr int kSize = 128;

  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({kSize})});

  // Twice the declared size -> spec validation should fail.
  c10::List<at::Tensor> input({at::ones({kSize * 2}, at::kFloat)});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(f.run(c10::impl::toList(input)), c10::Error);
}
148
149 } // namespace nnc
150 } // namespace mobile
151 } // namespace jit
152 } // namespace torch
153