#include <gtest/gtest.h>
#include <torch/csrc/jit/mobile/nnc/context.h>
#include <torch/csrc/jit/mobile/nnc/registry.h>
#include <ATen/Functions.h>

namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

extern "C" {

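// Each test kernel receives a flat array of raw buffer pointers. For
// slow_mul_kernel below these are: the input `a`, the output, the
// parameter `n`, and a scratch buffer supplied by the memory plan.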
// out = a * n, computed by repeatedly adding `a` into the `tmp` buffer.
int slow_mul_kernel(void** args) {
  const int size = 128;
  at::Tensor a = at::from_blob(args[0], {size}, at::kFloat);
  at::Tensor out = at::from_blob(args[1], {size}, at::kFloat);
  at::Tensor n = at::from_blob(args[2], {1}, at::kInt);
  at::Tensor tmp = at::from_blob(args[3], {size}, at::kFloat);

  tmp.zero_();
  for (int i = n.item().toInt(); i > 0; i--) {
    tmp.add_(a);
  }
  out.copy_(tmp);
  return 0;
}

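// A kernel that does no work; used below to exercise input validation.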
int dummy_kernel(void** /* args */) {
  return 0;
}

} // extern "C"

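// Register the kernels so that Function::run can look them up by kernel id.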
REGISTER_NNC_KERNEL("slow_mul", slow_mul_kernel)
REGISTER_NNC_KERNEL("dummy", dummy_kernel)

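// Helpers that build the float input/output specs and the memory plan used
// by the tests below.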
InputSpec create_test_input_spec(const std::vector<int64_t>& sizes) {
  InputSpec input_spec;
  input_spec.sizes_ = sizes;
  input_spec.dtype_ = at::kFloat;
  return input_spec;
}

OutputSpec create_test_output_spec(const std::vector<int64_t>& sizes) {
  OutputSpec output_spec;
  output_spec.sizes_ = sizes;
  output_spec.dtype_ = at::kFloat;
  return output_spec;
}

MemoryPlan create_test_memory_plan(const std::vector<int64_t>& buffer_sizes) {
  MemoryPlan memory_plan;
  memory_plan.buffer_sizes_ = buffer_sizes;
  return memory_plan;
}

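// End-to-end run of slow_mul: the parameter list carries `n`, the memory
// plan supplies the scratch buffer, and the output must equal a * n.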
TEST(Function, ExecuteSlowMul) {
  const int a = 999;
  const int n = 100;
  const int size = 128;
  Function f;

  f.set_nnc_kernel_id("slow_mul");
  f.set_input_specs({create_test_input_spec({size})});
  f.set_output_specs({create_test_output_spec({size})});
  f.set_parameters(c10::impl::toList(c10::List<at::Tensor>({
      at::ones({1}, at::kInt).mul(n)
  })));
  f.set_memory_plan(create_test_memory_plan({sizeof(float) * size}));

  c10::List<at::Tensor> input({
      at::ones({size}, at::kFloat).mul(a)
  });
  auto outputs = f.run(c10::impl::toList(input));
  auto output = ((const c10::IValue&) outputs[0]).toTensor();
  auto expected_output = at::ones({size}, at::kFloat).mul(a * n);
  EXPECT_TRUE(output.equal(expected_output));
}

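// Serialization round trip: every field set on `f` should be recovered
// when deserializing into `f2`.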
TEST(Function, Serialization) {
  Function f;
  f.set_name("test_function");
  f.set_nnc_kernel_id("test_kernel");
  f.set_input_specs({create_test_input_spec({1, 3, 224, 224})});
  f.set_output_specs({create_test_output_spec({1000})});

  f.set_parameters(c10::impl::toList(c10::List<at::Tensor>({
      at::ones({1, 16, 3, 3}, at::kFloat),
      at::ones({16, 32, 1, 1}, at::kFloat),
      at::ones({32, 1, 3, 3}, at::kFloat)
  })));
  f.set_memory_plan(create_test_memory_plan({
      sizeof(float) * 1024,
      sizeof(float) * 2048,
  }));

  auto serialized = f.serialize();
  Function f2(serialized);
  EXPECT_EQ(f2.name(), "test_function");
  EXPECT_EQ(f2.nnc_kernel_id(), "test_kernel");
  EXPECT_EQ(f2.input_specs().size(), 1);
  EXPECT_EQ(f2.input_specs()[0].sizes_, std::vector<int64_t>({1, 3, 224, 224}));
  EXPECT_EQ(f2.input_specs()[0].dtype_, at::kFloat);

  EXPECT_EQ(f2.output_specs().size(), 1);
  EXPECT_EQ(f2.output_specs()[0].sizes_, std::vector<int64_t>({1000}));
  EXPECT_EQ(f2.output_specs()[0].dtype_, at::kFloat);

  EXPECT_EQ(f2.parameters().size(), 3);
  EXPECT_EQ(f2.parameters()[0].toTensor().sizes(), at::IntArrayRef({1, 16, 3, 3}));
  EXPECT_EQ(f2.parameters()[1].toTensor().sizes(), at::IntArrayRef({16, 32, 1, 1}));
  EXPECT_EQ(f2.parameters()[2].toTensor().sizes(), at::IntArrayRef({32, 1, 3, 3}));

  EXPECT_EQ(f2.memory_plan().buffer_sizes_.size(), 2);
  EXPECT_EQ(f2.memory_plan().buffer_sizes_[0], sizeof(float) * 1024);
  EXPECT_EQ(f2.memory_plan().buffer_sizes_[1], sizeof(float) * 2048);
}

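// An input that matches the declared spec should be accepted.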
TEST(Function, ValidInput) {
  const int size = 128;
  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({size})});

  c10::List<at::Tensor> input({
      at::ones({size}, at::kFloat)
  });
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_NO_THROW(
      f.run(c10::impl::toList(input)));
}

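// An input whose size does not match the declared spec should be rejected
// with a c10::Error.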
TEST(Function, InvalidInput) {
  const int size = 128;
  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({size})});

  c10::List<at::Tensor> input({
      at::ones({size * 2}, at::kFloat)
  });
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(
      f.run(c10::impl::toList(input)),
      c10::Error);
}

} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch