#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>

#include <ATen/NativeFunctions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/xnnpack/OpContext.h>

namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
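
// The tests below build ExternalCall nodes that dispatch to ATen kernels
// through functions registered in the NNC external-functions registry (e.g.
// "nnc_aten_conv1d"). Each test compares the ExternalCall result against a
// direct ATen call, running the statement through LLVMCodeGen when LLVM is
// enabled and through SimpleIREvaluator unconditionally.
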
TEST(ExternalCall, Conv1d_float) {
  BufHandle Input("Input", {1, 100, 115}, kFloat);
  BufHandle Weight("Weight", {100, 1, 7}, kFloat);
  BufHandle Bias("Bias", {100}, kFloat);
  BufHandle ResultBuf("Result", {1, 100, 115}, kFloat);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  int64_t groups = 100;
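
  // With groups equal to the channel count (100), this is a depthwise
  // convolution; kernel 7, padding 3, stride 1, dilation 1 preserve the
  // input length: (115 + 2 * 3 - 7) / 1 + 1 = 115, matching ResultBuf.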
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 100, 115}, options) * 5.f;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f;
  at::Tensor bias = at::ones({100}, options) * 11.f;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 100 * 115, 5.f);
  std::vector<float> weight_buf(100 * 1 * 7, 6.f);
  std::vector<float> bias_buf(100, 11.f);
  std::vector<float> result_buf(1 * 100 * 115, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Conv1d_int) {
  // A similar test, but now using kInt tensors
  BufHandle Input("Input", {1, 100, 115}, kInt);
  BufHandle Weight("Weight", {100, 1, 7}, kInt);
  BufHandle Bias("Bias", {100}, kInt);
  BufHandle ResultBuf("Result", {1, 100, 115}, kInt);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  int64_t groups = 100;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kInt)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 100, 115}, options) * 5;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6;
  at::Tensor bias = at::ones({100}, options) * 11;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  at::Tensor nnc_result;
  std::vector<int32_t> input_buf(1 * 100 * 115, 5);
  std::vector<int32_t> weight_buf(100 * 1 * 7, 6);
  std::vector<int32_t> bias_buf(100, 11);
  std::vector<int32_t> result_buf(1 * 100 * 115, -1);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Conv1d_nobias_noargs) {
  BufHandle Input("Input", {1, 1, 115}, kFloat);
  BufHandle Weight("Weight", {10, 1, 7}, kFloat);
  BufHandle ResultBuf("Result", {1, 10, 109}, kFloat);
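
  // No bias buffer and an empty args list: this relies on nnc_aten_conv1d
  // applying the same defaults as at::conv1d (stride 1, padding 0, dilation 1,
  // groups 1), which gives an output length of 115 - 7 + 1 = 109.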
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 1, 115}, options) * 5.f;
  at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f;
  at::Tensor ref = at::conv1d(input, weight);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 1 * 115, 5.f);
  std::vector<float> weight_buf(10 * 1 * 7, 6.f);
  std::vector<float> result_buf(1 * 10 * 109, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});

  llvm_codegen.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});

  ir_eval.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Conv2d_float) {
  BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
  BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat);
  BufHandle Bias("Bias", {16}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;
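
  // A 3x3 kernel with stride 2 and padding 1 halves each spatial dimension:
  // (224 + 2 * 1 - 3) / 2 + 1 = 112.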
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv2d",
          {Input, Weight, Bias},
          {stride, stride, pad, pad, dilation, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f;
  at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f;
  at::Tensor bias = at::ones({16}, options) * 11.f;
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 3 * 224 * 224, 5.f);
  std::vector<float> weight_buf(16 * 3 * 3 * 3, 6.f);
  std::vector<float> bias_buf(16, 11.f);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Conv2d_int) {
  // A similar test, but now using kInt tensors

  BufHandle Input("Input", {1, 3, 224, 224}, kInt);
  BufHandle Weight("Weight", {16, 3, 3, 3}, kInt);
  BufHandle Bias("Bias", {16}, kInt);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv2d",
          {Input, Weight, Bias},
          {stride, stride, pad, pad, dilation, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kInt)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5;
  at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6;
  at::Tensor bias = at::ones({16}, options) * 11;
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  at::Tensor nnc_result;
  std::vector<int32_t> input_buf(1 * 3 * 224 * 224, 5);
  std::vector<int32_t> weight_buf(16 * 3 * 3 * 3, 6);
  std::vector<int32_t> bias_buf(16, 11);
  std::vector<int32_t> result_buf(1 * 16 * 112 * 112, -1);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Conv2d_nobias_noargs) {
  BufHandle Input("Input", {1, 16, 112, 112}, kFloat);
  BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f;
  at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
  at::Tensor ref = at::conv2d(input, weight);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 16 * 112 * 112, 5.f);
  std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});

  llvm_codegen.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});

  ir_eval.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Addmm_float) {
  BufHandle Input("Input", {100, 300}, kFloat);
  BufHandle Mat1("Mat1", {100, 200}, kFloat);
  BufHandle Mat2("Mat2", {200, 300}, kFloat);
  BufHandle ResultBuf("Result", {100, 300}, kFloat);
  int64_t beta = 2;
  int64_t alpha = 2;
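
  // at::addmm computes beta * input + alpha * (mat1 @ mat2); beta and alpha
  // are forwarded to the external call as int64_t args.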
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({100, 300}, options) * 5.f;
  at::Tensor mat1 = at::ones({100, 200}, options) * 6.f;
  at::Tensor mat2 = at::ones({200, 300}, options) * 11.f;
  at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha);

  at::Tensor nnc_result;
  std::vector<float> input_buf(100 * 300, 5.f);
  std::vector<float> mat1_buf(100 * 200, 6.f);
  std::vector<float> mat2_buf(200 * 300, 11.f);
  std::vector<float> result_buf(100 * 300, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result});

  llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result});

  ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Embedding) {
  BufHandle Weight("Weight", {256, 100}, kFloat);
  BufHandle Indices("Indices", {1, 115}, kLong);
  BufHandle ResultBuf("Result", {1, 115, 100}, kFloat);
  int64_t padding_idx = -1;
  bool scale_grad_by_freq = false;
  bool sparse = false;
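
  // ExternalCall only carries int64_t extra args, so the bool flags are cast
  // to int64_t below.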
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_embedding",
          {Weight, Indices},
          {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f;
  at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6;
  at::Tensor ref =
      at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);

  at::Tensor nnc_result;
  std::vector<float> weight_buf(256 * 100, 5.f);
  std::vector<int64_t> indices_buf(1 * 115, 6);
  std::vector<float> result_buf(1 * 115 * 100, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result});

  llvm_codegen.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result});

  ir_eval.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, MaxReduction) {
  BufHandle Input("Input", {1, 115, 152}, kFloat);
  BufHandle ResultBuf("Result", {1, 152}, kFloat);
  int64_t dim = 1;
  bool keep_dim = false;
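
  // at::max along a dimension returns (values, indices); the external call
  // only produces the values, so the reference below keeps std::get<0>.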
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  at::Tensor input = at::ones({1, 115, 152}, options) * 5.f;
  at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim));

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 115 * 152, 5.f);
  std::vector<float> result_buf(1 * 152, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result});

  llvm_codegen.call({input_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result});

  ir_eval.call({input_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

#ifdef USE_XNNPACK

TEST(ExternalCall, Prepacked_Linear_float) {
  using namespace at::native::xnnpack;

  BufHandle Input("Input", {100, 200}, kFloat);
  BufHandle ResultBuf("Result", {100, 300}, kFloat);

  // Calculate reference result using at::linear.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input =
      at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200});
  at::Tensor bias = at::linspace(-10.0, 10.0, 300, options);
  at::Tensor ref = at::linear(input, weight, bias);

  // Create prepacked xnnpack context object.
  auto linear_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::linear_clamp_prepack", "")
          .typed<c10::intrusive_ptr<LinearOpContext>(
              at::Tensor,
              std::optional<at::Tensor>,
              const std::optional<at::Scalar>&,
              const std::optional<at::Scalar>&)>();
  auto prepacked = linear_clamp_prepack_op.call(
      weight, bias, std::optional<at::Scalar>(), std::optional<at::Scalar>());

  BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
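  // DummyPrepacked is just a placeholder in the IR signature; the actual
  // prepacked context object is passed at call time as prepacked.get().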
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_linear_clamp_run",
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 100 * 200);
  std::vector<float> result_buf(100 * 300, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Prepacked_Conv2d_float) {
  using namespace at::native::xnnpack;

  BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  // Calculate reference result using at::conv2d.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options)
                         .resize_({1, 3, 224, 224});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3});
  at::Tensor bias = at::linspace(-10.0, 10.0, 16, options);
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  // Create prepacked xnnpack context object.
  auto conv2d_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "")
          .typed<c10::intrusive_ptr<Conv2dOpContext>(
              at::Tensor,
              std::optional<at::Tensor>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              int64_t,
              const std::optional<at::Scalar>&,
              const std::optional<at::Scalar>&)>();
  auto prepacked = conv2d_clamp_prepack_op.call(
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups,
      std::optional<at::Scalar>(),
      std::optional<at::Scalar>());

  BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_conv2d_clamp_run",
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 1 * 3 * 224 * 224);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
}

#endif // USE_XNNPACK

TEST(ExternalCall, BinaryFloat) {
  using TensorFunc = std::function<at::Tensor(at::Tensor, at::Tensor)>;
  using Test = std::tuple<
      std::vector<int64_t>,
      std::vector<int64_t>,
      std::vector<int64_t>,
      TensorFunc,
      std::string>;
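  // Each Test tuple holds: shape of A, shape of B, result shape, the reference
  // ATen function, and the name of the corresponding NNC external function.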
  std::vector<Test> tests = {};
  tests.push_back(
      Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"});
  tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"});
  tests.push_back(
      Test{{100, 200}, {200, 300}, {100, 300}, at::mm, "nnc_aten_mm"});
  for (auto curTest : tests) {
    auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest;
    auto toExprHandleVec = [](std::vector<int64_t> v) {
      auto intV = std::vector<int>(v.begin(), v.end());
      return std::vector<ExprHandle>(intV.begin(), intV.end());
    };
    BufHandle A("A", toExprHandleVec(aShape), kFloat);
    BufHandle B("B", toExprHandleVec(bShape), kFloat);
    BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);

    Tensor Result = Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, externCallName, {A, B}, {}));
    LoopNest l({Result});
    l.prepareForCodegen();
    l.simplify();

    auto options = at::TensorOptions()
                       .dtype(at::kFloat)
                       .layout(at::kStrided)
                       .device(at::kCPU)
                       .requires_grad(false);
    at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
    at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f;
    at::Tensor ref = torchFunc(a, b);

    auto prod = [](std::vector<int64_t> v) {
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
    };

    at::Tensor nnc_result;
    std::vector<float> a_buf(prod(aShape), 5.f);
    std::vector<float> b_buf(prod(bShape), 6.f);
    std::vector<float> result_buf(prod(resShape), -1.f);

#ifdef TORCH_ENABLE_LLVM
    LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result});

    llvm_codegen.call({a_buf, b_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

    SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result});
    ir_eval.call({a_buf, b_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
  }
}

TEST(ExternalCall, UnaryFloat) {
  using TensorFunc = std::function<at::Tensor(at::Tensor)>;
  auto toExprHandleVec = [](std::vector<int64_t> v) {
    auto intV = std::vector<int>(v.begin(), v.end());
    return std::vector<ExprHandle>(intV.begin(), intV.end());
  };
  using Test = std::tuple<
      std::vector<int64_t>,
      std::vector<int64_t>,
      TensorFunc,
      std::string,
      std::vector<ExprHandle>>;
  std::vector<Test> tests = {};
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {1, 64, 8, 9},
                       {1, 64, 5, 7},
                       [](at::Tensor x) {
                         return at::adaptive_avg_pool2d(x, {5, 7});
                       },
                       "nnc_aten_adaptive_avg_pool2d",
                       toExprHandleVec({5, 7})});
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {100, 200},
                       {100},
                       [](at::Tensor x) { return at::mean(x, {1}); },
                       "nnc_aten_mean",
                       toExprHandleVec({1, /*keepdim=*/0})});
  for (auto curTest : tests) {
    auto [aShape, resShape, torchFunc, externCallName, externCallArgs] =
        curTest;
    BufHandle A("A", toExprHandleVec(aShape), kFloat);
    BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);

    Tensor Result = Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs));
    LoopNest l({Result});
    l.prepareForCodegen();
    l.simplify();

    auto options = at::TensorOptions()
                       .dtype(at::kFloat)
                       .layout(at::kStrided)
                       .device(at::kCPU)
                       .requires_grad(false);
    at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
    at::Tensor ref = torchFunc(a);

    auto prod = [](std::vector<int64_t> v) {
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
    };

    at::Tensor nnc_result;
    std::vector<float> a_buf(prod(aShape), 5.f);
    std::vector<float> result_buf(prod(resShape), -1.f);

#ifdef TORCH_ENABLE_LLVM
    LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result});

    llvm_codegen.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

    SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result});
    ir_eval.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
  }
}

TEST(ExternalCall, ComputeInterop) {
  // This test verifies that Tensors using external calls can be used by and
  // can use Tensors built with Compute API.

  BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat);
  BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat);

  Tensor Input = Compute(
      "Input",
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(5.0f); });
  Tensor Weight = Compute(
      "Weight",
      {16, 16, 1, 1},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(6.0f); });

  Tensor ConvResult = Tensor(
      ConvResultBuf.node(),
      ExternalCall::make(
          ConvResultBuf,
          "nnc_aten_conv2d",
          {BufHandle(Input.buf()), BufHandle(Weight.buf())},
          {}));
  Tensor MatmulResult = Tensor(
      MatmulResultBuf.node(),
      ExternalCall::make(
          MatmulResultBuf,
          "nnc_aten_matmul",
          {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())},
          {}));
  Tensor Result = Compute(
      "Result",
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) {
        return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w);
      });

  LoopNest l({Input, Weight, ConvResult, MatmulResult, Result});

  // Inlining should not inline anything here since all Bufs are either defined
  // or used in ExternalCalls - we run it just for testing
  l.inlineIntermediateBufs(true);

  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f;
  at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
  at::Tensor t = at::conv2d(input, weight);
  at::Tensor t2 = at::matmul(t, t);
  at::Tensor ref = t + t2;

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 16 * 32 * 32, 5.f);
  std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
  std::vector<float> conv_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> matmul_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> result_buf(1 * 16 * 32 * 32, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  llvm_codegen.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  ir_eval.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Inlining) {
  // This test verifies that Tensors using external calls can be used by and
  // can use Tensors built with Compute API.

  BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);

  Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
    return FloatImm::make(5.0f);
  });
  Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
    return FloatImm::make(4.0f);
  });
  Tensor MatmulResult = Tensor(
      MatmulResultBuf.node(),
      ExternalCall::make(
          MatmulResultBuf,
          "nnc_aten_matmul",
          {BufHandle(A.buf()), BufHandle(B.buf())},
          {}));
  Tensor Result =
      Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
        return MatmulResult.load(i, j) + FloatImm::make(3.0f);
      });

  StmtPtr root_stmt = alloc<torch::jit::tensorexpr::Block>(std::vector<StmtPtr>(
      {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()}));
  LoopNest l(root_stmt, {Result.buf()});

  // Inlining should not inline anything here since all Bufs are either
  // defined or used in ExternalCalls
  l.inlineIntermediateBufs(false);

  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor a = at::ones({8, 8}, options) * 5.f;
  at::Tensor b = at::ones({8, 8}, options) * 4.f;
  at::Tensor t = at::matmul(a, b);
  at::Tensor ref = t + 3.f;

  at::Tensor nnc_result;
  std::vector<float> result_buf(8 * 8);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Result});

  llvm_codegen.call({result_buf});
  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Result});

  ir_eval.call({result_buf});
  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, JitCustomFusionOp) {
  const char* custom_op_schema_literal =
      "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor";
  const char* external_func_name = "nnc_add_mul";

  auto add_mul_lowering_func =
      [external_func_name](
          const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
          const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
          const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
          const std::optional<torch::jit::tensorexpr::ScalarType>& output_type,
          at::Device device) {
        auto output_dtype = Dtype(*output_type);
        torch::jit::tensorexpr::BufHandle result_buf(
            "nnc_add_mul_res_buf", output_shape, output_dtype);
        const torch::jit::tensorexpr::BufHandle& a =
            std::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
        const torch::jit::tensorexpr::BufHandle& b =
            std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
        const torch::jit::tensorexpr::BufHandle& c =
            std::get<torch::jit::tensorexpr::BufHandle>(inputs[2]);
        torch::jit::tensorexpr::StmtPtr s =
            torch::jit::tensorexpr::ExternalCall::make(
                result_buf, external_func_name, {a, b, c}, {});
        return Tensor(result_buf.node(), s);
      };

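  // The external function is deliberately a no-op stub: this test only checks
  // that the custom op is fused into a TensorExprGroup and never executes the
  // generated kernel.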
  auto add_mul_external_func = [](int64_t bufs_num,
                                  void** buf_data,
                                  int64_t* buf_ranks,
                                  int64_t* buf_dims,
                                  int64_t* buf_strides,
                                  int8_t* buf_dtypes,
                                  int64_t args_num,
                                  int64_t* extra_args) {};

  torch::jit::RegisterOperators reg({Operator(
      custom_op_schema_literal,
      [](const Node* node) -> Operation {
        return [](Stack& _stack) {
          auto a = std::move(peek(_stack, 0, 3)).toTensor();
          auto b = std::move(peek(_stack, 1, 3)).toTensor();
          auto c = std::move(peek(_stack, 2, 3)).toTensor();
          drop(_stack, 3);
          auto result = (a + b) * c;
          pack(_stack, std::move(result));
          return 0;
        };
      },
      c10::AliasAnalysisKind::FROM_SCHEMA)});

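  // Three registrations make the custom op fusable: the custom operator set
  // marks the schema as fusable, the lowering registry maps the schema to its
  // NNC lowering, and the function registry maps the external function name to
  // the kernel to invoke.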
  auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet();
  custom_operator_set.insert({custom_op_schema_literal});

  auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry();
  te_lowering_registry.insert(
      parseSchema(custom_op_schema_literal), add_mul_lowering_func);

  auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry();
  te_nnc_func_registry[external_func_name] = add_mul_external_func;

  std::string graph_string = R"IR(
    graph(%a : Float(10, 20, strides=[20, 1], device=cpu),
          %b : Float(10, 20, strides=[20, 1], device=cpu),
          %c : Float(10, 20, strides=[20, 1], device=cpu)):
      %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c)
      return (%res))IR";

  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::string shape_compute_python_string = R"PY(
  def computeOutput(a: List[int], b: List[int], c: List[int]):
    expandedSizes: List[int] = []
    dimsA = len(a)
    dimsB = len(b)
    dimsC = len(c)
    ndim = max(dimsA, dimsB, dimsC)
    for i in range(ndim):
      offset = ndim - 1 - i
      dimA = dimsA - 1 - offset
      dimB = dimsB - 1 - offset
      dimC = dimsC - 1 - offset
      sizeA = a[dimA] if (dimA >= 0) else 1
      sizeB = b[dimB] if (dimB >= 0) else 1
      sizeC = c[dimC] if (dimC >= 0) else 1

      if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1:
        # TODO: only assertion error is bound in C++ compilation right now
        raise AssertionError(
          "The size of tensor a {} must match the size of tensor b ("
          "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i)
        )

      expandedSizes.append(max(sizeA, sizeB, sizeC))

    return expandedSizes
  )PY";
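  // The Python function above computes the broadcast output shape for add_mul;
  // it is compiled here and registered below as the shape function for the
  // custom op schema so that fusion with dynamic shapes can infer output
  // sizes.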
  auto cu_ptr = torch::jit::compile(shape_compute_python_string);
  torch::jit::GraphFunction* gf =
      (torch::jit::GraphFunction*)&cu_ptr->get_function("computeOutput");
  ASSERT_TRUE(gf);

#ifdef TORCH_ENABLE_LLVM
  auto static_graph_case = graph->copy();
  FuseTensorExprs(static_graph_case, 1);
  torch::jit::testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("nnc_custom::add_mul")
      ->run(*static_graph_case);

  auto dynamic_graph_case = graph->copy();
  auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal);
  ASSERT_TRUE(custom_op);
  torch::jit::RegisterShapeComputeGraphForSchema(
      custom_op->schema(), gf->graph());
  FuseTensorExprs(dynamic_graph_case, 1, false, true);
  torch::jit::testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("nnc_custom::add_mul")
      ->run(*dynamic_graph_case);
#else
  torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph);
#endif
}

} // namespace jit
} // namespace torch