• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/tensorexpr/test_base.h>
4 
5 #include <torch/csrc/jit/ir/irparser.h>
6 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
7 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
8 #include <torch/csrc/jit/runtime/custom_operator.h>
9 #include <torch/csrc/jit/tensorexpr/kernel.h>
10 
11 #include <test/cpp/tensorexpr/test_utils.h>
12 #include <torch/csrc/jit/runtime/operator.h>
13 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
14 #include <torch/csrc/jit/tensorexpr/eval.h>
15 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
16 #include <torch/csrc/jit/tensorexpr/ir.h>
17 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
18 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
19 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
20 #include <torch/csrc/jit/tensorexpr/loopnest.h>
21 #include <torch/csrc/jit/tensorexpr/tensor.h>
22 
23 #include <torch/csrc/jit/testing/file_check.h>
24 #include <torch/jit.h>
25 
26 #include <ATen/NativeFunctions.h>
27 #include <ATen/core/dispatch/Dispatcher.h>
28 #include <ATen/native/xnnpack/OpContext.h>
29 
30 namespace torch {
31 namespace jit {
32 using namespace torch::jit::tensorexpr;
33 
// Lowers an ExternalCall to "nnc_aten_conv1d" (float tensors, explicit
// stride/pad/dilation/groups) and checks that both the LLVM codegen (when
// built with LLVM) and the SimpleIREvaluator reproduce at::conv1d.
TEST(ExternalCall, Conv1d_float) {
  BufHandle Input("Input", {1, 100, 115}, kFloat);
  BufHandle Weight("Weight", {100, 1, 7}, kFloat);
  BufHandle Bias("Bias", {100}, kFloat);
  BufHandle ResultBuf("Result", {1, 100, 115}, kFloat);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  int64_t groups = 100; // depthwise: one group per channel

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  // Reference result computed directly with ATen.
  at::Tensor input = at::ones({1, 100, 115}, options) * 5.f;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f;
  at::Tensor bias = at::ones({100}, options) * 11.f;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  at::Tensor nnc_result;
  // Flat buffers mirroring the tensors above; result is poisoned with -1
  // so untouched output would fail the comparison.
  std::vector<float> input_buf(1 * 100 * 115, 5.f);
  std::vector<float> weight_buf(100 * 1 * 7, 6.f);
  std::vector<float> bias_buf(100, 11.f);
  std::vector<float> result_buf(1 * 100 * 115, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
86 
// Same scenario as Conv1d_float, but exercising the external call with
// kInt tensors to verify the dtype is propagated correctly.
TEST(ExternalCall, Conv1d_int) {
  // A similar test, but now using kInt tensors
  BufHandle Input("Input", {1, 100, 115}, kInt);
  BufHandle Weight("Weight", {100, 1, 7}, kInt);
  BufHandle Bias("Bias", {100}, kInt);
  BufHandle ResultBuf("Result", {1, 100, 115}, kInt);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  int64_t groups = 100; // depthwise: one group per channel

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kInt)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  // Reference result computed directly with ATen.
  at::Tensor input = at::ones({1, 100, 115}, options) * 5;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6;
  at::Tensor bias = at::ones({100}, options) * 11;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  at::Tensor nnc_result;
  // Result buffer poisoned with -1 so stale output fails the comparison.
  std::vector<int32_t> input_buf(1 * 100 * 115, 5);
  std::vector<int32_t> weight_buf(100 * 1 * 7, 6);
  std::vector<int32_t> bias_buf(100, 11);
  std::vector<int32_t> result_buf(1 * 100 * 115, -1);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
140 
// Conv1d through the external call with no bias buffer and no extra scalar
// args: the kernel must fall back to its defaults (matching at::conv1d with
// only input and weight).
TEST(ExternalCall, Conv1d_nobias_noargs) {
  BufHandle Input("Input", {1, 1, 115}, kFloat);
  BufHandle Weight("Weight", {10, 1, 7}, kFloat);
  BufHandle ResultBuf("Result", {1, 10, 109}, kFloat);

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  // Reference result computed directly with ATen (default conv params).
  at::Tensor input = at::ones({1, 1, 115}, options) * 5.f;
  at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f;
  at::Tensor ref = at::conv1d(input, weight);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 1 * 115, 5.f);
  std::vector<float> weight_buf(10 * 1 * 7, 6.f);
  std::vector<float> result_buf(1 * 10 * 109, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});

  llvm_codegen.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});

  ir_eval.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
181 
// Lowers an ExternalCall to "nnc_aten_conv2d" (float tensors, per-dimension
// stride/pad/dilation plus groups) and compares against at::conv2d via both
// the LLVM codegen (when available) and the SimpleIREvaluator.
TEST(ExternalCall, Conv2d_float) {
  BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
  BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat);
  BufHandle Bias("Bias", {16}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv2d",
          {Input, Weight, Bias},
          // conv2d takes the H/W pair for each spatial parameter.
          {stride, stride, pad, pad, dilation, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  // Reference result computed directly with ATen.
  at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f;
  at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f;
  at::Tensor bias = at::ones({16}, options) * 11.f;
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  at::Tensor nnc_result;
  // Result buffer poisoned with -1 so stale output fails the comparison.
  std::vector<float> input_buf(1 * 3 * 224 * 224, 5.f);
  std::vector<float> weight_buf(16 * 3 * 3 * 3, 6.f);
  std::vector<float> bias_buf(16, 11.f);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
240 
TEST(ExternalCall,Conv2d_int)241 TEST(ExternalCall, Conv2d_int) {
242   // A similar test, but now using kInt tensors
243 
244   BufHandle Input("Input", {1, 3, 224, 224}, kInt);
245   BufHandle Weight("Weight", {16, 3, 3, 3}, kInt);
246   BufHandle Bias("Bias", {16}, kInt);
247   BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt);
248   int64_t stride = 2;
249   int64_t pad = 1;
250   int64_t dilation = 1;
251   int64_t groups = 1;
252 
253   Tensor Result = Tensor(
254       ResultBuf.node(),
255       ExternalCall::make(
256           ResultBuf,
257           "nnc_aten_conv2d",
258           {Input, Weight, Bias},
259           {stride, stride, pad, pad, dilation, dilation, groups}));
260   LoopNest l({Result});
261   l.prepareForCodegen();
262   l.simplify();
263 
264   auto options = at::TensorOptions()
265                      .dtype(at::kInt)
266                      .layout(at::kStrided)
267                      .device(at::kCPU)
268                      .requires_grad(false);
269   at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5;
270   at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6;
271   at::Tensor bias = at::ones({16}, options) * 11;
272   at::Tensor ref = at::conv2d(
273       input,
274       weight,
275       bias,
276       {stride, stride},
277       {pad, pad},
278       {dilation, dilation},
279       groups);
280 
281   at::Tensor nnc_result;
282   std::vector<int32_t> input_buf(1 * 3 * 224 * 224, 5);
283   std::vector<int32_t> weight_buf(16 * 3 * 3 * 3, 6);
284   std::vector<int32_t> bias_buf(16, 11);
285   std::vector<int32_t> result_buf(1 * 16 * 112 * 112, -1);
286 
287 #ifdef TORCH_ENABLE_LLVM
288   LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});
289 
290   llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
291   nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
292   ASSERT_TRUE(at::allclose(nnc_result, ref));
293 #endif
294 
295   SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});
296 
297   ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
298   nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
299   ASSERT_TRUE(at::allclose(nnc_result, ref));
300 }
301 
// Conv2d through the external call with no bias and no scalar args: the
// kernel must apply its defaults (matching at::conv2d(input, weight)).
TEST(ExternalCall, Conv2d_nobias_noargs) {
  BufHandle Input("Input", {1, 16, 112, 112}, kFloat);
  BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  // Reference result computed directly with ATen (default conv params).
  at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f;
  at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
  at::Tensor ref = at::conv2d(input, weight);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 16 * 112 * 112, 5.f);
  std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});

  llvm_codegen.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});

  ir_eval.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
342 
// Lowers an ExternalCall to "nnc_aten_addmm" (beta * Input + alpha *
// Mat1 @ Mat2) and compares against at::addmm through both backends.
TEST(ExternalCall, Addmm_float) {
  BufHandle Input("Input", {100, 300}, kFloat);
  BufHandle Mat1("Mat1", {100, 200}, kFloat);
  BufHandle Mat2("Mat2", {200, 300}, kFloat);
  BufHandle ResultBuf("Result", {100, 300}, kFloat);
  int64_t beta = 2;
  int64_t alpha = 2;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  // Reference result computed directly with ATen.
  at::Tensor input = at::ones({100, 300}, options) * 5.f;
  at::Tensor mat1 = at::ones({100, 200}, options) * 6.f;
  at::Tensor mat2 = at::ones({200, 300}, options) * 11.f;
  at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha);

  at::Tensor nnc_result;
  std::vector<float> input_buf(100 * 300, 5.f);
  std::vector<float> mat1_buf(100 * 200, 6.f);
  std::vector<float> mat2_buf(200 * 300, 11.f);
  std::vector<float> result_buf(100 * 300, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result});

  llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result});

  ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
389 
// Lowers an ExternalCall to "nnc_aten_embedding" — note the mixed dtypes
// (kFloat weight, kLong indices) — and compares against at::embedding.
TEST(ExternalCall, Embedding) {
  BufHandle Weight("Weight", {256, 100}, kFloat);
  BufHandle Indices("Indices", {1, 115}, kLong);
  BufHandle ResultBuf("Result", {1, 115, 100}, kFloat);
  int64_t padding_idx = -1;
  bool scale_grad_by_freq = false;
  bool sparse = false;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_embedding",
          {Weight, Indices},
          // Scalar args are int64-only, so bools are encoded as int64_t.
          {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  // Reference result computed directly with ATen.
  at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f;
  at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6;
  at::Tensor ref =
      at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);

  at::Tensor nnc_result;
  std::vector<float> weight_buf(256 * 100, 5.f);
  std::vector<int64_t> indices_buf(1 * 115, 6);
  std::vector<float> result_buf(1 * 115 * 100, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result});

  llvm_codegen.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result});

  ir_eval.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
440 
// Lowers an ExternalCall to "nnc_aten_max_red" (max-reduction along one
// dimension) and compares against the values component of at::max(dim).
TEST(ExternalCall, MaxReduction) {
  BufHandle Input("Input", {1, 115, 152}, kFloat);
  BufHandle ResultBuf("Result", {1, 152}, kFloat);
  int64_t dim = 1;
  bool keep_dim = false;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          // Scalar args are int64-only, so keep_dim is encoded as int64_t.
          ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  // at::max(dim) returns (values, indices); only the values are compared.
  at::Tensor input = at::ones({1, 115, 152}, options) * 5.f;
  at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim));

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 115 * 152, 5.f);
  std::vector<float> result_buf(1 * 152, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result});

  llvm_codegen.call({input_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result});

  ir_eval.call({input_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
482 
483 #ifdef USE_XNNPACK
484 
TEST(ExternalCall,Prepacked_Linear_float)485 TEST(ExternalCall, Prepacked_Linear_float) {
486   using namespace at::native::xnnpack;
487 
488   BufHandle Input("Input", {100, 200}, kFloat);
489   BufHandle ResultBuf("Result", {100, 300}, kFloat);
490 
491   // Calculate reference result using at::linear.
492   auto options = at::TensorOptions()
493                      .dtype(at::kFloat)
494                      .layout(at::kStrided)
495                      .device(at::kCPU)
496                      .requires_grad(false);
497   at::Tensor input =
498       at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200});
499   at::Tensor weight =
500       at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200});
501   at::Tensor bias = at::linspace(-10.0, 10.0, 300, options);
502   at::Tensor ref = at::linear(input, weight, bias);
503 
504   // Create prepacked xnnpack context object.
505   auto linear_clamp_prepack_op =
506       c10::Dispatcher::singleton()
507           .findSchemaOrThrow("prepacked::linear_clamp_prepack", "")
508           .typed<c10::intrusive_ptr<LinearOpContext>(
509               at::Tensor,
510               std::optional<at::Tensor>,
511               const std::optional<at::Scalar>&,
512               const std::optional<at::Scalar>&)>();
513   auto prepacked = linear_clamp_prepack_op.call(
514       weight, bias, std::optional<at::Scalar>(), std::optional<at::Scalar>());
515 
516   BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
517   Tensor Result = Tensor(
518       ResultBuf.node(),
519       ExternalCall::make(
520           ResultBuf,
521           "nnc_prepacked_linear_clamp_run",
522           {Input, DummyPrepacked},
523           {}));
524   LoopNest l({Result});
525   l.prepareForCodegen();
526   l.simplify();
527 
528   at::Tensor nnc_result;
529   std::vector<float> input_buf(
530       input.data_ptr<float>(), input.data_ptr<float>() + 100 * 200);
531   std::vector<float> result_buf(100 * 300, -1.f);
532 
533 #ifdef TORCH_ENABLE_LLVM
534   LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});
535 
536   llvm_codegen.call({input_buf, prepacked.get(), result_buf});
537   nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
538   ASSERT_TRUE(at::allclose(nnc_result, ref));
539 #endif
540 
541   SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});
542 
543   ir_eval.call({input_buf, prepacked.get(), result_buf});
544   nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
545   ASSERT_TRUE(at::allclose(nnc_result, ref));
546 }
547 
TEST(ExternalCall,Prepacked_Conv2d_float)548 TEST(ExternalCall, Prepacked_Conv2d_float) {
549   using namespace at::native::xnnpack;
550 
551   BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
552   BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
553   int64_t stride = 2;
554   int64_t pad = 1;
555   int64_t dilation = 1;
556   int64_t groups = 1;
557 
558   // Calculate reference result using at::conv2d.
559   auto options = at::TensorOptions()
560                      .dtype(at::kFloat)
561                      .layout(at::kStrided)
562                      .device(at::kCPU)
563                      .requires_grad(false);
564   at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options)
565                          .resize_({1, 3, 224, 224});
566   at::Tensor weight =
567       at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3});
568   at::Tensor bias = at::linspace(-10.0, 10.0, 16, options);
569   at::Tensor ref = at::conv2d(
570       input,
571       weight,
572       bias,
573       {stride, stride},
574       {pad, pad},
575       {dilation, dilation},
576       groups);
577 
578   // Create prepacked xnnpack context object.
579   auto conv2d_clamp_prepack_op =
580       c10::Dispatcher::singleton()
581           .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "")
582           .typed<c10::intrusive_ptr<Conv2dOpContext>(
583               at::Tensor,
584               std::optional<at::Tensor>,
585               std::vector<int64_t>,
586               std::vector<int64_t>,
587               std::vector<int64_t>,
588               int64_t,
589               const std::optional<at::Scalar>&,
590               const std::optional<at::Scalar>&)>();
591   auto prepacked = conv2d_clamp_prepack_op.call(
592       weight,
593       bias,
594       {stride, stride},
595       {pad, pad},
596       {dilation, dilation},
597       groups,
598       std::optional<at::Scalar>(),
599       std::optional<at::Scalar>());
600 
601   BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
602   Tensor Result = Tensor(
603       ResultBuf.node(),
604       ExternalCall::make(
605           ResultBuf,
606           "nnc_prepacked_conv2d_clamp_run",
607           {Input, DummyPrepacked},
608           {}));
609   LoopNest l({Result});
610   l.prepareForCodegen();
611   l.simplify();
612 
613   at::Tensor nnc_result;
614   std::vector<float> input_buf(
615       input.data_ptr<float>(), input.data_ptr<float>() + 1 * 3 * 224 * 224);
616   std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);
617 
618 #ifdef TORCH_ENABLE_LLVM
619   LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});
620 
621   llvm_codegen.call({input_buf, prepacked.get(), result_buf});
622   nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
623   ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
624 #endif
625 
626   SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});
627 
628   ir_eval.call({input_buf, prepacked.get(), result_buf});
629   nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
630   ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
631 }
632 
633 #endif // USE_XNNPACK
634 
TEST(ExternalCall,BinaryFloat)635 TEST(ExternalCall, BinaryFloat) {
636   using TensorFunc = std::function<at::Tensor(at::Tensor, at::Tensor)>;
637   using Test = std::tuple<
638       std::vector<int64_t>,
639       std::vector<int64_t>,
640       std::vector<int64_t>,
641       TensorFunc,
642       std::string>;
643   std::vector<Test> tests = {};
644   tests.push_back(
645       Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"});
646   tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"});
647   tests.push_back(
648       Test{{100, 200}, {200, 300}, {100, 300}, at::mm, "nnc_aten_mm"});
649   for (auto curTest : tests) {
650     auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest;
651     auto toExprHandleVec = [](std::vector<int64_t> v) {
652       auto intV = std::vector<int>(v.begin(), v.end());
653       return std::vector<ExprHandle>(intV.begin(), intV.end());
654     };
655     BufHandle A("A", toExprHandleVec(aShape), kFloat);
656     BufHandle B("B", toExprHandleVec(bShape), kFloat);
657     BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);
658 
659     Tensor Result = Tensor(
660         ResultBuf.node(),
661         ExternalCall::make(ResultBuf, externCallName, {A, B}, {}));
662     LoopNest l({Result});
663     l.prepareForCodegen();
664     l.simplify();
665 
666     auto options = at::TensorOptions()
667                        .dtype(at::kFloat)
668                        .layout(at::kStrided)
669                        .device(at::kCPU)
670                        .requires_grad(false);
671     at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
672     at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f;
673     at::Tensor ref = torchFunc(a, b);
674 
675     auto prod = [](std::vector<int64_t> v) {
676       // NOLINTNEXTLINE(modernize-use-transparent-functors)
677       return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
678     };
679 
680     at::Tensor nnc_result;
681     std::vector<float> a_buf(prod(aShape), 5.f);
682     std::vector<float> b_buf(prod(bShape), 6.f);
683     std::vector<float> result_buf(prod(resShape), -1.f);
684 
685 #ifdef TORCH_ENABLE_LLVM
686     LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result});
687 
688     llvm_codegen.call({a_buf, b_buf, result_buf});
689     nnc_result =
690         at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
691     ASSERT_TRUE(at::allclose(nnc_result, ref));
692 #endif
693 
694     SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result});
695     ir_eval.call({a_buf, b_buf, result_buf});
696     nnc_result =
697         at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
698     ASSERT_TRUE(at::allclose(nnc_result, ref));
699   }
700 }
701 
TEST(ExternalCall,UnaryFloat)702 TEST(ExternalCall, UnaryFloat) {
703   using TensorFunc = std::function<at::Tensor(at::Tensor)>;
704   auto toExprHandleVec = [](std::vector<int64_t> v) {
705     auto intV = std::vector<int>(v.begin(), v.end());
706     return std::vector<ExprHandle>(intV.begin(), intV.end());
707   };
708   using Test = std::tuple<
709       std::vector<int64_t>,
710       std::vector<int64_t>,
711       TensorFunc,
712       std::string,
713       std::vector<ExprHandle>>;
714   std::vector<Test> tests = {};
715   tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
716                        {1, 64, 8, 9},
717                        {1, 64, 5, 7},
718                        [](at::Tensor x) {
719                          return at::adaptive_avg_pool2d(x, {5, 7});
720                        },
721                        "nnc_aten_adaptive_avg_pool2d",
722                        toExprHandleVec({5, 7})});
723   tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
724                        {100, 200},
725                        {100},
726                        [](at::Tensor x) { return at::mean(x, {1}); },
727                        "nnc_aten_mean",
728                        toExprHandleVec({1, /*keepdim=*/0})});
729   for (auto curTest : tests) {
730     auto [aShape, resShape, torchFunc, externCallName, externCallArgs] =
731         curTest;
732     BufHandle A("A", toExprHandleVec(aShape), kFloat);
733     BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);
734 
735     Tensor Result = Tensor(
736         ResultBuf.node(),
737         ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs));
738     LoopNest l({Result});
739     l.prepareForCodegen();
740     l.simplify();
741 
742     auto options = at::TensorOptions()
743                        .dtype(at::kFloat)
744                        .layout(at::kStrided)
745                        .device(at::kCPU)
746                        .requires_grad(false);
747     at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
748     at::Tensor ref = torchFunc(a);
749 
750     auto prod = [](std::vector<int64_t> v) {
751       // NOLINTNEXTLINE(modernize-use-transparent-functors)
752       return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
753     };
754 
755     at::Tensor nnc_result;
756     std::vector<float> a_buf(prod(aShape), 5.f);
757     std::vector<float> result_buf(prod(resShape), -1.f);
758 
759 #ifdef TORCH_ENABLE_LLVM
760     LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result});
761 
762     llvm_codegen.call({a_buf, result_buf});
763     nnc_result =
764         at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
765     ASSERT_TRUE(at::allclose(nnc_result, ref));
766 #endif
767 
768     SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result});
769     ir_eval.call({a_buf, result_buf});
770     nnc_result =
771         at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
772     ASSERT_TRUE(at::allclose(nnc_result, ref));
773   }
774 }
775 
TEST(ExternalCall,ComputeInterop)776 TEST(ExternalCall, ComputeInterop) {
777   // This test verifies that Tensors using external calls can be used by and can
778   // use Tensors built with Compute API.
779 
780   BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat);
781   BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat);
782 
783   Tensor Input = Compute(
784       "Input",
785       {1, 16, 32, 32},
786       [&](const VarHandle& n,
787           const VarHandle& c,
788           const VarHandle& h,
789           const VarHandle& w) { return FloatImm::make(5.0f); });
790   Tensor Weight = Compute(
791       "Weight",
792       {16, 16, 1, 1},
793       [&](const VarHandle& n,
794           const VarHandle& c,
795           const VarHandle& h,
796           const VarHandle& w) { return FloatImm::make(6.0f); });
797 
798   Tensor ConvResult = Tensor(
799       ConvResultBuf.node(),
800       ExternalCall::make(
801           ConvResultBuf,
802           "nnc_aten_conv2d",
803           {BufHandle(Input.buf()), BufHandle(Weight.buf())},
804           {}));
805   Tensor MatmulResult = Tensor(
806       MatmulResultBuf.node(),
807       ExternalCall::make(
808           MatmulResultBuf,
809           "nnc_aten_matmul",
810           {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())},
811           {}));
812   Tensor Result = Compute(
813       "Result",
814       {1, 16, 32, 32},
815       [&](const VarHandle& n,
816           const VarHandle& c,
817           const VarHandle& h,
818           const VarHandle& w) {
819         return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w);
820       });
821 
822   LoopNest l({Input, Weight, ConvResult, MatmulResult, Result});
823 
824   // Inlining should not inline anything here since all Bufs are either defined
825   // or used in ExternalCalls - we run it just for testing
826   l.inlineIntermediateBufs(true);
827 
828   l.prepareForCodegen();
829   l.simplify();
830 
831   auto options = at::TensorOptions()
832                      .dtype(at::kFloat)
833                      .layout(at::kStrided)
834                      .device(at::kCPU)
835                      .requires_grad(false);
836   at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f;
837   at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
838   at::Tensor t = at::conv2d(input, weight);
839   at::Tensor t2 = at::matmul(t, t);
840   at::Tensor ref = t + t2;
841 
842   at::Tensor nnc_result;
843   std::vector<float> input_buf(1 * 16 * 32 * 32, 5.f);
844   std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
845   std::vector<float> conv_result_buf(1 * 16 * 32 * 32, -1.f);
846   std::vector<float> matmul_result_buf(1 * 16 * 32 * 32, -1.f);
847   std::vector<float> result_buf(1 * 16 * 32 * 32, -1.f);
848 
849 #ifdef TORCH_ENABLE_LLVM
850   LLVMCodeGen llvm_codegen(
851       l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});
852 
853   llvm_codegen.call(
854       {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
855   nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
856   ASSERT_TRUE(at::allclose(nnc_result, ref));
857 #endif
858 
859   SimpleIREvaluator ir_eval(
860       l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});
861 
862   ir_eval.call(
863       {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
864   nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
865   ASSERT_TRUE(at::allclose(nnc_result, ref));
866 }
867 
TEST(ExternalCall,Inlining)868 TEST(ExternalCall, Inlining) {
869   // This test verifies that Tensors using external calls can be used by and
870   // can use Tensors built with Compute API.
871 
872   BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);
873 
874   Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
875     return FloatImm::make(5.0f);
876   });
877   Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
878     return FloatImm::make(4.0f);
879   });
880   Tensor MatmulResult = Tensor(
881       MatmulResultBuf.node(),
882       ExternalCall::make(
883           MatmulResultBuf,
884           "nnc_aten_matmul",
885           {BufHandle(A.buf()), BufHandle(B.buf())},
886           {}));
887   Tensor Result =
888       Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
889         return MatmulResult.load(i, j) + FloatImm::make(3.0f);
890       });
891 
892   StmtPtr root_stmt = alloc<torch::jit::tensorexpr::Block>(std::vector<StmtPtr>(
893       {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()}));
894   LoopNest l(root_stmt, {Result.buf()});
895 
896   // Inlining should not inline anything here since all Bufs are either
897   // defined or used in ExternalCalls
898   l.inlineIntermediateBufs(false);
899 
900   l.prepareForCodegen();
901   l.simplify();
902 
903   auto options = at::TensorOptions()
904                      .dtype(at::kFloat)
905                      .layout(at::kStrided)
906                      .device(at::kCPU)
907                      .requires_grad(false);
908   at::Tensor a = at::ones({8, 8}, options) * 5.f;
909   at::Tensor b = at::ones({8, 8}, options) * 4.f;
910   at::Tensor t = at::matmul(a, b);
911   at::Tensor ref = t + 3.f;
912 
913   at::Tensor nnc_result;
914   std::vector<float> result_buf(8 * 8);
915 
916 #ifdef TORCH_ENABLE_LLVM
917   LLVMCodeGen llvm_codegen(l.root_stmt(), {Result});
918 
919   llvm_codegen.call({result_buf});
920   nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
921   ASSERT_TRUE(at::allclose(nnc_result, ref));
922 #endif
923 
924   SimpleIREvaluator ir_eval(l.root_stmt(), {Result});
925 
926   ir_eval.call({result_buf});
927   nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
928   ASSERT_TRUE(at::allclose(nnc_result, ref));
929 }
930 
TEST(ExternalCall,JitCustomFusionOp)931 TEST(ExternalCall, JitCustomFusionOp) {
932   const char* custom_op_schema_literal =
933       "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor";
934   const char* external_func_name = "nnc_add_mul";
935 
936   auto add_mul_lowering_func =
937       [external_func_name](
938           const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
939           const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
940           const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
941           const std::optional<torch::jit::tensorexpr::ScalarType>& output_type,
942           at::Device device) {
943         auto output_dtype = Dtype(*output_type);
944         torch::jit::tensorexpr::BufHandle result_buf(
945             "nnc_add_mul_res_buf", output_shape, output_dtype);
946         const torch::jit::tensorexpr::BufHandle& a =
947             std::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
948         const torch::jit::tensorexpr::BufHandle& b =
949             std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
950         const torch::jit::tensorexpr::BufHandle& c =
951             std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
952         torch::jit::tensorexpr::StmtPtr s =
953             torch::jit::tensorexpr::ExternalCall::make(
954                 result_buf, external_func_name, {a, b, c}, {});
955         return Tensor(result_buf.node(), s);
956       };
957 
958   auto add_mul_external_func = [](int64_t bufs_num,
959                                   void** buf_data,
960                                   int64_t* buf_ranks,
961                                   int64_t* buf_dims,
962                                   int64_t* buf_strides,
963                                   int8_t* buf_dtypes,
964                                   int64_t args_num,
965                                   int64_t* extra_args) {};
966 
967   torch::jit::RegisterOperators reg({Operator(
968       custom_op_schema_literal,
969       [](const Node* node) -> Operation {
970         return [](Stack& _stack) {
971           auto a = std::move(peek(_stack, 0, 3)).toTensor();
972           auto b = std::move(peek(_stack, 1, 3)).toTensor();
973           auto c = std::move(peek(_stack, 2, 3)).toTensor();
974           drop(_stack, 3);
975           auto result = (a + b) * c;
976           pack(_stack, std::move(result));
977           return 0;
978         };
979       },
980       c10::AliasAnalysisKind::FROM_SCHEMA)});
981 
982   auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet();
983   custom_operator_set.insert({custom_op_schema_literal});
984 
985   auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry();
986   te_lowering_registry.insert(
987       parseSchema(custom_op_schema_literal), add_mul_lowering_func);
988 
989   auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry();
990   te_nnc_func_registry[external_func_name] = add_mul_external_func;
991 
992   std::string graph_string = R"IR(
993     graph(%a : Float(10, 20, strides=[20, 1], device=cpu),
994           %b : Float(10, 20, strides=[20, 1], device=cpu),
995           %c : Float(10, 20, strides=[20, 1], device=cpu)):
996       %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c)
997       return (%res))IR";
998 
999   auto graph = std::make_shared<Graph>();
1000   torch::jit::parseIR(graph_string, graph.get());
1001 
1002   std::string shape_compute_python_string = R"PY(
1003   def computOutput(a: List[int], b: List[int], c: List[int]):
1004     expandedSizes: List[int] = []
1005     dimsA = len(a)
1006     dimsB = len(b)
1007     dimsC = len(c)
1008     ndim = max(dimsA, dimsB, dimsC)
1009     for i in range(ndim):
1010         offset = ndim - 1 - i
1011         dimA = dimsA - 1 - offset
1012         dimB = dimsB - 1 - offset
1013         dimC = dimsC - 1 - offset
1014         sizeA = a[dimA] if (dimA >= 0) else 1
1015         sizeB = b[dimB] if (dimB >= 0) else 1
1016         sizeC = a[dimC] if (dimC >= 0) else 1
1017 
1018         if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1:
1019             # TODO: only assertion error is bound in C++ compilation right now
1020             raise AssertionError(
1021                 "The size of tensor a {} must match the size of tensor b ("
1022                 "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i)
1023             )
1024 
1025         expandedSizes.append(max(sizeA, sizeB, sizeC))
1026 
1027     return expandedSizes
1028   )PY";
1029   auto cu_ptr = torch::jit::compile(shape_compute_python_string);
1030   torch::jit::GraphFunction* gf =
1031       (torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput");
1032   ASSERT_TRUE(gf);
1033 
1034 #ifdef TORCH_ENABLE_LLVM
1035   auto static_graph_case = graph->copy();
1036   FuseTensorExprs(static_graph_case, 1);
1037   torch::jit::testing::FileCheck()
1038       .check("prim::TensorExprGroup_")
1039       ->check("nnc_custom::add_mul")
1040       ->run(*static_graph_case);
1041 
1042   auto dynamic_graph_case = graph->copy();
1043   auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal);
1044   ASSERT_TRUE(custom_op);
1045   torch::jit::RegisterShapeComputeGraphForSchema(
1046       custom_op->schema(), gf->graph());
1047   FuseTensorExprs(dynamic_graph_case, 1, false, true);
1048   torch::jit::testing::FileCheck()
1049       .check("prim::TensorExprGroup_")
1050       ->check("nnc_custom::add_mul")
1051       ->run(*dynamic_graph_case);
1052 #else
1053   torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph);
1054 #endif
1055 }
1056 
1057 } // namespace jit
1058 } // namespace torch
1059