1 #include <gtest/gtest.h>
2 
3 #include <ATen/code_template.h>
4 #include <c10/core/DeviceType.h>
5 #include <test/cpp/tensorexpr/test_base.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/csrc/jit/ir/irparser.h>
8 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
9 #include <torch/csrc/jit/tensorexpr/kernel.h>
10 #include <torch/csrc/jit/testing/file_check.h>
11 #include <torch/torch.h>
12 #include <cmath>
13 #include <sstream>
14 #include <stdexcept>
15 #include <thread>
16 
17 namespace torch {
18 namespace jit {
19 
20 using namespace torch::indexing;
21 using namespace torch::jit::tensorexpr;
22 
TEST(DynamicShapes,SimpleGraph)23 TEST(DynamicShapes, SimpleGraph) {
24 #ifdef TORCH_ENABLE_LLVM
25   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
26   const auto graph_string = R"IR(
27       graph(%x : Tensor,
28             %SS_2 : int,
29             %SS_3 : int):
30         %3 : Tensor = aten::tanh(%x)
31         %4 : Tensor = aten::erf(%3)
32         return (%4))IR";
33   torch::jit::parseIR(graph_string, graph.get());
34 
35   auto x_inp = graph->inputs()[0];
36   auto x_type = TensorType::create(at::rand({10, 5}));
37   std::vector<ShapeSymbol> x_sym_dims(
38       {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
39   auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
40   graph->inputs().at(0)->setType(x_sym_type);
41   for (const auto n : graph->nodes()) {
42     n->output()->setType(x_sym_type);
43   }
44 
45   // Graph with symbolic shapes:
46   //
47   // graph(%x : Float(SS(-2), SS(-3)),
48   //       %SS_2 : int,
49   //       %SS_3 : int):
50   //   %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
51   //   %4 : Float(SS(-2), SS(-3)) = aten::erf(%3)
52   //   return (%4)
53 
54   std::vector<torch::jit::StrideInput> input_desc = {
55       torch::jit::StrideInput::TENSOR_CONT};
56   std::unordered_map<
57       const torch::jit::Value*,
58       std::vector<torch::jit::StrideInput>>
59       symbolic_strides;
60   symbolic_strides[x_inp] = input_desc;
61   symbolic_strides[graph->outputs().at(0)] = input_desc;
62   std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
63       x_sym_dims,
64       [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
65 
66   TensorExprKernel kernel(
67       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
68   // Run with the same static dims as the one we initialized the graph with.
69   {
70     auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
71     auto ref = at::erf(at::tanh(a));
72 
73     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
74     stack.push_back(10);
75     stack.push_back(5);
76     kernel.run(stack);
77 
78     auto o = stack[0].toTensor();
79     ASSERT_TRUE(at::allclose(o, ref));
80   }
81 
82   // Run with inputs having different dims.
83   {
84     auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
85     auto ref = at::erf(at::tanh(a));
86 
87     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
88     stack.push_back(50);
89     stack.push_back(100);
90     kernel.run(stack);
91 
92     auto o = stack[0].toTensor();
93     ASSERT_TRUE(at::allclose(o, ref));
94   }
95 #endif
96 }
97 
TEST(DynamicShapes,GraphWith2InputsSameDims)98 TEST(DynamicShapes, GraphWith2InputsSameDims) {
99 #ifdef TORCH_ENABLE_LLVM
100   // The two inputs in this graph must have the same dims.
101   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
102   const auto graph_string = R"IR(
103       graph(%x : Tensor,
104             %y : Tensor,
105             %SS_2 : int,
106             %SS_3 : int):
107         %3 : Tensor = aten::tanh(%x)
108         %4 : Tensor = aten::erf(%3)
109         %5 : Tensor = aten::mul(%4, %y)
110         return (%5))IR";
111   torch::jit::parseIR(graph_string, graph.get());
112 
113   auto x_inp = graph->inputs()[0];
114   auto y_inp = graph->inputs()[1];
115   auto x_type = TensorType::create(at::rand({10, 5}));
116   std::vector<ShapeSymbol> x_sym_dims(
117       {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
118   auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
119   graph->inputs().at(0)->setType(x_sym_type);
120   graph->inputs().at(1)->setType(x_sym_type);
121   for (const auto n : graph->nodes()) {
122     n->output()->setType(x_sym_type);
123   }
124 
125   // Graph with symbolic shapes:
126   //
127   // graph(%x : Float(SS(-4), SS(-5)),
128   //       %y : Float(SS(-4), SS(-5)),
129   //       %SS_2 : int,
130   //       %SS_3 : int):
131   //   %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x)
132   //   %5 : Float(SS(-4), SS(-5)) = aten::erf(%4)
133   //   %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y)
134   //   return (%6)
135 
136   std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
137       x_sym_dims,
138       [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
139 
140   std::vector<torch::jit::StrideInput> input_desc = {
141       torch::jit::StrideInput::TENSOR_CONT};
142   std::unordered_map<
143       const torch::jit::Value*,
144       std::vector<torch::jit::StrideInput>>
145       symbolic_strides;
146   symbolic_strides[x_inp] = input_desc;
147   symbolic_strides[y_inp] = input_desc;
148   symbolic_strides[graph->outputs().at(0)] = input_desc;
149 
150   TensorExprKernel kernel(
151       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
152 
153   // Run with the same static dims as the one we initialized the graph with.
154   {
155     auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
156     auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
157     auto ref = at::mul(at::erf(at::tanh(a)), b);
158 
159     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
160     stack.push_back(10);
161     stack.push_back(5);
162     kernel.run(stack);
163 
164     auto o = stack[0].toTensor();
165     ASSERT_TRUE(at::allclose(o, ref));
166   }
167 
168   // Run with inputs having different dims.
169   {
170     auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
171     auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
172     auto ref = at::mul(at::erf(at::tanh(a)), b);
173 
174     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
175     stack.push_back(50);
176     stack.push_back(100);
177     kernel.run(stack);
178 
179     auto o = stack[0].toTensor();
180     ASSERT_TRUE(at::allclose(o, ref));
181   }
182 #endif
183 }
184 
TEST(DynamicShapes,GraphWith2InputsAndBroadcast)185 TEST(DynamicShapes, GraphWith2InputsAndBroadcast) {
186 #ifdef TORCH_ENABLE_LLVM
187   // The second input to the graph has a dim of size 1 which should be
188   // broadcasted in the at::mul op.
189   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
190   const auto graph_string = R"IR(
191       graph(%x : Float(10, 5, requires_grad=0, device=cpu),
192             %y : Float(1, 5, requires_grad=0, device=cpu),
193             %SS_2 : int,
194             %SS_3 : int):
195         %3 : Tensor = aten::tanh(%x)
196         %4 : Tensor = aten::erf(%3)
197         %5 : Tensor = aten::mul(%4, %y)
198         return (%5))IR";
199   torch::jit::parseIR(graph_string, graph.get());
200 
201   auto x_inp = graph->inputs()[0];
202   auto y_inp = graph->inputs()[1];
203   auto x_type = TensorType::create(at::rand({10, 5}));
204   auto y_type = TensorType::create(at::rand({1, 5}));
205   auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
206   auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
207   auto x_sym_type = x_type->withSymbolicShapes(
208       std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
209   auto y_sym_type = y_type->withSymbolicShapes(std::vector<ShapeSymbol>(
210       {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
211   graph->inputs().at(0)->setType(x_sym_type);
212   graph->inputs().at(1)->setType(y_sym_type);
213   for (const auto n : graph->nodes()) {
214     n->output()->setType(x_sym_type);
215   }
216 
217   // Graph with symbolic shapes:
218   //
219   // graph(%x : Float(SS(-6), SS(-7)),
220   //       %y : Float(1, SS(-7)),
221   //       %SS_2 : int,
222   //       %SS_3 : int):
223   //   %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x)
224   //   %5 : Float(SS(-6), SS(-7)) = aten::erf(%4)
225   //   %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y)
226   //   return (%6)
227 
228   std::vector<int64_t> symbolic_shape_inputs(
229       {x_dim0_sym.value(), x_dim1_sym.value()});
230 
231   std::vector<torch::jit::StrideInput> input_desc = {
232       torch::jit::StrideInput::TENSOR_CONT};
233   std::unordered_map<
234       const torch::jit::Value*,
235       std::vector<torch::jit::StrideInput>>
236       symbolic_strides;
237   symbolic_strides[x_inp] = input_desc;
238   symbolic_strides[y_inp] = input_desc;
239   symbolic_strides[graph->outputs().at(0)] = input_desc;
240 
241   TensorExprKernel kernel(
242       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
243 
244   // Run with the same static dims as the one we initialized the graph with.
245   {
246     auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
247     auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
248     auto ref = at::mul(at::erf(at::tanh(a)), b);
249 
250     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
251     stack.push_back(10);
252     stack.push_back(5);
253     kernel.run(stack);
254 
255     auto o = stack[0].toTensor();
256     ASSERT_TRUE(at::allclose(o, ref));
257   }
258 
259   // Run with inputs having different dims.
260   {
261     auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
262     auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
263     auto ref = at::mul(at::erf(at::tanh(a)), b);
264 
265     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
266     stack.push_back(50);
267     stack.push_back(100);
268     kernel.run(stack);
269 
270     auto o = stack[0].toTensor();
271     ASSERT_TRUE(at::allclose(o, ref));
272   }
273 #endif
274 }
275 
TEST(DynamicShapes,GraphWithPartiallySymbolicOutput)276 TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) {
277 #ifdef TORCH_ENABLE_LLVM
278   // The second input to the graph has a dim of size 1 which should be
279   // broadcasted in the at::mul op.
280   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
281   const auto graph_string = R"IR(
282       graph(%x : Float(1, 5, requires_grad=0, device=cpu),
283             %y : Float(1, 5, requires_grad=0, device=cpu),
284             %SS_2 : int):
285         %4 : Tensor = aten::tanh(%x)
286         %5 : Tensor = aten::mul(%4, %y)
287         return (%5))IR";
288   torch::jit::parseIR(graph_string, graph.get());
289 
290   auto x_inp = graph->inputs()[0];
291   auto y_inp = graph->inputs()[1];
292   auto x_type = TensorType::create(at::rand({1, 5}));
293   auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
294   auto x_sym_type = x_type->withSymbolicShapes(std::vector<ShapeSymbol>(
295       {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
296   graph->inputs().at(0)->setType(x_sym_type);
297   graph->inputs().at(1)->setType(x_sym_type);
298   for (const auto n : graph->nodes()) {
299     n->output()->setType(x_sym_type);
300   }
301 
302   // Graph with symbolic shapes:
303   //
304   // graph(%x : Float(1, SS(-2)),
305   //       %y : Float(1, SS(-2)),
306   //       %SS_2 : int):
307   //   %3 : Float(1, SS(-2)) = aten::tanh(%x)
308   //   %4 : Float(1, SS(-2)) = aten::mul(%3, %y)
309   //   return (%4)
310 
311   std::vector<int64_t> symbolic_shape_inputs({x_dim1_sym.value()});
312 
313   std::vector<torch::jit::StrideInput> input_desc = {
314       torch::jit::StrideInput::TENSOR_CONT};
315   std::unordered_map<
316       const torch::jit::Value*,
317       std::vector<torch::jit::StrideInput>>
318       symbolic_strides;
319   symbolic_strides[x_inp] = input_desc;
320   symbolic_strides[y_inp] = input_desc;
321   symbolic_strides[graph->outputs().at(0)] = input_desc;
322 
323   TensorExprKernel kernel(
324       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
325 
326   // Run with the same static dims as the one we initialized the graph with.
327   {
328     auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
329     auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
330     auto ref = at::mul(at::tanh(a), b);
331 
332     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
333     stack.push_back(5);
334     kernel.run(stack);
335 
336     auto o = stack[0].toTensor();
337     ASSERT_TRUE(at::allclose(o, ref));
338   }
339 
340   // Run with inputs having different dims.
341   {
342     auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
343     auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
344     auto ref = at::mul(at::tanh(a), b);
345 
346     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
347     stack.push_back(100);
348     kernel.run(stack);
349 
350     auto o = stack[0].toTensor();
351     ASSERT_TRUE(at::allclose(o, ref));
352   }
353 #endif
354 }
355 
TEST(DynamicShapes, GraphWithSymbolicStrides) {
#ifdef TORCH_ENABLE_LLVM
  // The IR already carries symbolic shapes. The inputs' strides are described
  // symbolically: the outer stride is S_AS_ARG and the innermost is S_ONE. The
  // output is declared contiguous. In the signature, %SS_3 (columns) comes
  // before %SS_2 (rows), and the shape arguments follow that order.
  const auto graph_string = R"IR(
    graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
          %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
          %SS_3 : int,
          %SS_2 : int):
      %15 : int = prim::Constant[value=1]()
      %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15)
      %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0)
      return (%22))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  const std::vector<torch::jit::StrideInput> in_strides = {
      torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE};
  const std::vector<torch::jit::StrideInput> out_strides = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides{
          {graph->inputs().at(0), in_strides},
          {graph->inputs().at(1), in_strides},
          {graph->outputs().at(0), out_strides}};
  std::vector<int64_t> symbolic_shape_inputs = {-3, -2};
  TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  const auto opts = at::TensorOptions(at::kCPU).dtype(at::kFloat);
  // Shape arguments in signature order: %SS_3 = 32 columns, %SS_2 = 10 rows.
  auto append_shape_args = [](std::vector<IValue>& stack) {
    stack.push_back(32);
    stack.push_back(10);
  };

  // The kernel allocates the output itself.
  {
    auto lhs = at::rand({10, 32}, opts);
    auto rhs = at::rand({10, 32}, opts);
    auto expected = at::mul(at::add(lhs, rhs, 1), lhs);

    std::vector<at::Tensor> tensors = {lhs, rhs};
    std::vector<IValue> stack = at::fmap<at::IValue>(tensors);
    append_shape_args(stack);
    k.run(stack);

    ASSERT_TRUE(at::allclose(stack[0].toTensor(), expected));
  }

  // The caller provides the output buffer, which goes first on the stack
  // ahead of the inputs.
  {
    auto lhs = at::rand({10, 32}, opts);
    auto rhs = at::rand({10, 32}, opts);
    auto result = at::rand({10, 32}, opts);
    auto expected = at::mul(at::add(lhs, rhs, 1), lhs);

    std::vector<at::Tensor> tensors = {result, lhs, rhs};
    std::vector<IValue> stack = at::fmap<at::IValue>(tensors);
    append_shape_args(stack);
    k.runWithAllocatedOutputs(stack);

    ASSERT_TRUE(at::allclose(result, expected));
  }
#endif
}
416 
TEST(DynamicShapes,GraphWithCatAndBroadcast)417 TEST(DynamicShapes, GraphWithCatAndBroadcast) {
418 #ifdef TORCH_ENABLE_LLVM
419   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
420   const auto graph_string = R"IR(
421       graph(%x : Float(10, 5, requires_grad=0, device=cpu),
422             %y : Float(4, 5, requires_grad=0, device=cpu),
423             %z : Float(1, 1, requires_grad=0, device=cpu),
424             %SS_2 : int,
425             %SS_3 : int,
426             %SS_4 : int,
427             %SS_5 : int):
428         %11 : int = prim::Constant[value=0]()
429         %3 : Tensor = aten::tanh(%x)
430         %out1 : Tensor = aten::erf(%3)
431         %out2 : Tensor = aten::relu(%y)
432         %10 : Tensor[] = prim::ListConstruct(%out1, %out2)
433         %25 : Tensor = aten::cat(%10, %11)
434         %28 : Tensor = aten::hardswish(%25)
435         %29 : Tensor = aten::mul(%28, %z)
436         return (%29))IR";
437   torch::jit::parseIR(graph_string, graph.get());
438 
439   auto x_inp = graph->inputs()[0];
440   auto y_inp = graph->inputs()[1];
441   auto z_inp = graph->inputs()[2];
442   auto x_type = TensorType::create(at::rand({10, 5}));
443   auto y_type = TensorType::create(at::rand({4, 5}));
444   auto z_type = TensorType::create(at::rand({1, 1}));
445   auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
446   auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
447   auto x_sym_type = x_type->withSymbolicShapes(
448       std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
449   auto y_dim0_sym = c10::ShapeSymbol::newSymbol();
450   auto y_sym_type = y_type->withSymbolicShapes(
451       std::vector<ShapeSymbol>({y_dim0_sym, x_dim1_sym}));
452   graph->inputs().at(0)->setType(x_sym_type);
453   graph->inputs().at(1)->setType(y_sym_type);
454   auto cat_dim0_sym = c10::ShapeSymbol::newSymbol();
455   auto cat_out_type = x_type->withSymbolicShapes(
456       std::vector<ShapeSymbol>({cat_dim0_sym, x_dim1_sym}));
457   auto nodeIt = graph->nodes().begin();
458   ++nodeIt;
459   nodeIt->output()->setType(x_sym_type); // aten::tanh
460   ++nodeIt;
461   nodeIt->output()->setType(x_sym_type); // aten::erf
462   ++nodeIt;
463   nodeIt->output()->setType(y_sym_type); // aten::relu
464   ++nodeIt;
465   ++nodeIt;
466   nodeIt->output()->setType(cat_out_type); // aten::cat
467   ++nodeIt;
468   nodeIt->output()->setType(cat_out_type); // aten::hardswish
469   ++nodeIt;
470   nodeIt->output()->setType(cat_out_type); // aten::mul
471 
472   // Graph with symbolic shapes:
473   //
474   // graph(%x : Float(SS(-2), SS(-3)),
475   //       %y : Float(SS(-4), SS(-3)),
476   //       %z : Float(1, 1),
477   //       %SS_2 : int,
478   //       %SS_3 : int,
479   //       %SS_4 : int,
480   //       %SS_5 : int):
481   //   %7 : int = prim::Constant[value=0]()
482   //   %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
483   //   %9 : Float(SS(-2), SS(-3)) = aten::erf(%8)
484   //   %10 : Float(SS(-4), SS(-3)) = aten::relu(%y)
485   //   %11 : Tensor[] = prim::ListConstruct(%9, %10)
486   //   %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7)
487   //   %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12)
488   //   %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z)
489   //   return (%14)
490 
491   std::vector<int64_t> symbolic_shape_inputs(
492       {x_dim0_sym.value(),
493        x_dim1_sym.value(),
494        y_dim0_sym.value(),
495        cat_dim0_sym.value()});
496 
497   std::vector<torch::jit::StrideInput> input_desc = {
498       torch::jit::StrideInput::TENSOR_CONT};
499   std::unordered_map<
500       const torch::jit::Value*,
501       std::vector<torch::jit::StrideInput>>
502       symbolic_strides;
503   symbolic_strides[x_inp] = input_desc;
504   symbolic_strides[y_inp] = input_desc;
505   symbolic_strides[z_inp] = input_desc;
506   symbolic_strides[graph->outputs().at(0)] = input_desc;
507 
508   TensorExprKernel kernel(
509       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
510 
511   auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
512   auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
513   auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
514   auto ref = at::mul(
515       at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c);
516 
517   std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
518   stack.push_back(10);
519   stack.push_back(5);
520   stack.push_back(4);
521   stack.push_back(14);
522   kernel.run(stack);
523 
524   auto o = stack[0].toTensor();
525   ASSERT_TRUE(at::allclose(o, ref));
526 #endif
527 }
528 
TEST(DynamicShapes,GraphFromModel)529 TEST(DynamicShapes, GraphFromModel) {
530 #ifdef TORCH_ENABLE_LLVM
531   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
532   const auto graph_string = R"IR(
533     graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
534           %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu),
535           %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu),
536           %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu),
537           %4 : Float(SS(-7), requires_grad=0, device=cpu),
538           %5 : Float(SS(-7), requires_grad=0, device=cpu),
539           %SS_10 : int,
540           %SS_9 : int,
541           %SS_8 : int,
542           %SS_7 : int,
543           %SS_6 : int,
544           %SS_5 : int,
545           %SS_4 : int,
546           %SS_3 : int,
547           %SS_2 : int):
548       %15 : int = prim::Constant[value=1]()
549       %16 : bool = prim::Constant[value=0]()
550       %17 : int = prim::Constant[value=6]()
551       %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16)
552       %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2)
553       %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15)
554       %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15)
555       %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4)
556       return (%22))IR";
557   parseIR(graph_string, &*graph);
558 
559   std::vector<torch::jit::StrideInput> input_desc = {
560       torch::jit::StrideInput::TENSOR_CONT};
561   std::unordered_map<
562       const torch::jit::Value*,
563       std::vector<torch::jit::StrideInput>>
564       symbolic_strides;
565   symbolic_strides[graph->inputs().at(0)] = input_desc;
566   symbolic_strides[graph->inputs().at(1)] = input_desc;
567   symbolic_strides[graph->inputs().at(2)] = input_desc;
568   symbolic_strides[graph->inputs().at(3)] = input_desc;
569   symbolic_strides[graph->inputs().at(4)] = input_desc;
570   symbolic_strides[graph->inputs().at(5)] = input_desc;
571   symbolic_strides[graph->outputs().at(0)] = input_desc;
572   std::vector<int64_t> symbolic_shape_inputs = {
573       -10, -9, -8, -7, -6, -5, -4, -3, -2};
574   TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);
575 
576   int64_t i2 = 10;
577   int64_t i3 = 32;
578   int64_t i4 = 19;
579   int64_t i5 = 71;
580   int64_t i6 = 139;
581   int64_t i7 = 261;
582   int64_t i8 = 261;
583   int64_t i9 = 261;
584   int64_t i10 = 261;
585   auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
586   auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
587   auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
588   auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong));
589   auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
590   auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
591   auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4);
592 
593   {
594     std::vector<at::Tensor> inputs = {x0, x1, x2, x3, x4, x5};
595     std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
596     stack.emplace_back(i10);
597     stack.emplace_back(i9);
598     stack.emplace_back(i8);
599     stack.emplace_back(i7);
600     stack.emplace_back(i6);
601     stack.emplace_back(i5);
602     stack.emplace_back(i4);
603     stack.emplace_back(i3);
604     stack.emplace_back(i2);
605     k.run(stack);
606 
607     auto o = stack[0].toTensor();
608     ASSERT_TRUE(at::allclose(o, ref));
609   }
610 
611   {
612     auto out =
613         at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
614     std::vector<at::Tensor> inputs = {out, x0, x1, x2, x3, x4, x5};
615     std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
616     stack.emplace_back(i10);
617     stack.emplace_back(i9);
618     stack.emplace_back(i8);
619     stack.emplace_back(i7);
620     stack.emplace_back(i6);
621     stack.emplace_back(i5);
622     stack.emplace_back(i4);
623     stack.emplace_back(i3);
624     stack.emplace_back(i2);
625     k.runWithAllocatedOutputs(stack);
626 
627     ASSERT_TRUE(at::allclose(out, ref));
628   }
629 #endif
630 }
631 
TEST(DynamicShapes,MultiThreadedExecution)632 TEST(DynamicShapes, MultiThreadedExecution) {
633 #ifdef TORCH_ENABLE_LLVM
634   const auto graph_template = R"IR(
635       graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
636             %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
637             %SS_2 : int,
638             %SS_3 : int):
639         %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x)
640         %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3)
641         %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y)
642         return (%5))IR";
643   for (bool use_cuda : {false, true}) {
644     if (!torch::cuda::is_available() && use_cuda) {
645       continue;
646     }
647     auto device = use_cuda ? at::kCUDA : at::kCPU;
648     at::jit::TemplateEnv env;
649     env.s("device", use_cuda ? "cuda:0" : "cpu");
650     const auto graph_string = format(graph_template, env);
651     std::shared_ptr<Graph> graph = std::make_shared<Graph>();
652     torch::jit::parseIR(graph_string, graph.get());
653 
654     std::vector<int64_t> symbolic_shape_inputs = {-2, -3};
655 
656     std::vector<torch::jit::StrideInput> input_desc = {
657         torch::jit::StrideInput::TENSOR_CONT};
658     std::unordered_map<
659         const torch::jit::Value*,
660         std::vector<torch::jit::StrideInput>>
661         symbolic_strides;
662     symbolic_strides[graph->inputs().at(0)] = input_desc;
663     symbolic_strides[graph->inputs().at(1)] = input_desc;
664     symbolic_strides[graph->outputs().at(0)] = input_desc;
665 
666     TensorExprKernel kernel(
667         graph, {}, symbolic_shape_inputs, false, symbolic_strides);
668 
669     auto run_kernel = [&](int dim1, int dim2) {
670       auto a =
671           at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));
672       auto b =
673           at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));
674 
675       auto ref = at::mul(at::erf(at::tanh(a)), b);
676 
677       std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
678       stack.emplace_back(dim1);
679       stack.emplace_back(dim2);
680       kernel.run(stack);
681 
682       auto o = stack[0].toTensor();
683       ASSERT_TRUE(at::allclose(o, ref));
684     };
685 
686     // Run the kernel in parallel to ensure that the run() method calls in
687     // TensorExprKernel are not changing any state.
688     constexpr size_t kNumThreads = 4;
689     std::vector<std::thread> threads;
690     for (size_t id = 0; id < kNumThreads; ++id) {
691       threads.emplace_back(run_kernel, id + 5, id + 20);
692     }
693     for (auto& t : threads) {
694       t.join();
695     }
696   }
697 #endif
698 }
699 
700 } // namespace jit
701 } // namespace torch
702