1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/tensorexpr/test_base.h>
4 #include <memory>
5 #include <sstream>
6 #include <stdexcept>
7 #include <unordered_map>
8 
9 #include <test/cpp/tensorexpr/padded_buffer.h>
10 #include <test/cpp/tensorexpr/test_utils.h>
11 #include <torch/csrc/jit/tensorexpr/analysis.h>
12 #include <torch/csrc/jit/tensorexpr/bounds_inference.h>
13 #include <torch/csrc/jit/tensorexpr/eval.h>
14 #include <torch/csrc/jit/tensorexpr/ir.h>
15 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
16 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
17 #include <torch/csrc/jit/tensorexpr/loopnest.h>
18 #include <torch/csrc/jit/tensorexpr/tensor.h>
19 #include <torch/csrc/jit/testing/file_check.h>
20 
21 namespace torch {
22 namespace jit {
23 
24 using namespace torch::jit::tensorexpr;
25 
checkIR(StmtPtr s,const std::string & pattern)26 void checkIR(StmtPtr s, const std::string& pattern) {
27   std::ostringstream oss;
28   oss << *s;
29   torch::jit::testing::FileCheck().run(pattern, oss.str());
30 }
31 
checkExprIR(ExprPtr e,const std::string & pattern)32 void checkExprIR(ExprPtr e, const std::string& pattern) {
33   std::string prefixed_pattern = "# CHECK: " + pattern + "\n";
34   std::ostringstream oss;
35   oss << *e << "\n";
36   torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str());
37 }
38 
checkExprIR(const ExprHandle & e,const std::string & pattern)39 void checkExprIR(const ExprHandle& e, const std::string& pattern) {
40   checkExprIR(e.node(), pattern);
41 }
42 
// Smoke test with no assertions: builds a 16x5 compute and applies
// splitWithTail (factor 2) to the same outer loop handle twice in a row.
TEST(LoopNest, ExprSimple01) {
  auto compute_fn = [](const VarHandle& x, const VarHandle& y) {
    return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
  };
  Tensor tensor = Compute("f", {16, 5}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  for (int round = 0; round < 2; ++round) {
    LoopNest::splitWithTail(loops[0], 2);
  }
}
54 
// Lowers a 16x5 compute and sanity-checks the length of the printed root
// statement (more than 20 and fewer than 200 characters).
TEST(LoopNest, ExprLower01) {
  auto compute_fn = [](const VarHandle& x, const VarHandle& y) {
    return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
  };
  Tensor tensor = Compute("f", {16, 5}, compute_fn);
  LoopNest nest({tensor});

  StmtPtr root = nest.root_stmt();
  std::ostringstream printed;
  printed << *root;
  const std::string ir_text = printed.str();

  ASSERT_GT(ir_text.size(), 20);
  ASSERT_LT(ir_text.size(), 200);
}
67 
TEST(LoopNest,ExprSimple02)68 TEST(LoopNest, ExprSimple02) {
69   auto func = [](const ExprHandle& x, const ExprHandle& y) {
70     return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
71   };
72   Tensor tensor = Compute("f", {26, 5}, func);
73   LoopNest l({tensor});
74   std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
75 
76   LoopNest::splitWithTail(loops[0], 4);
77 
78   StmtPtr stmt = l.root_stmt();
79   std::ostringstream oss;
80   oss << *stmt;
81   ASSERT_GT(oss.str().size(), 200);
82   ASSERT_LT(oss.str().size(), 600);
83 
84   {
85     // Compare to a reference loop structure structure.
86     VarHandle x_outer("i_outer", kInt);
87     VarHandle x_inner("i_inner", kInt);
88     VarHandle y("i", kInt);
89     VarHandle x_tail("i_tail", kInt);
90     BufHandle f("f", {26, 5}, kFloat);
91     ExprHandle x_1 = x_outer * 4 + x_inner;
92     ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4;
93     ForPtr stmt1 = For::make(
94         x_outer,
95         0,
96         x_outer_end,
97         For::make(
98             x_inner,
99             0,
100             4,
101             For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))));
102     ExprHandle x_2 = x_tail + x_outer_end * 4;
103     ForPtr stmt2 = For::make(
104         x_tail,
105         0,
106         (ExprHandle(26) - 0) % 4,
107         For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y))));
108     StmtPtr stmt = Block::make({stmt1, stmt2});
109 
110     std::ostringstream oss_ref;
111     oss_ref << *stmt;
112     ASSERT_EQ(oss.str(), oss_ref.str());
113   }
114 
115   {
116     PaddedBuffer<float> f_v(26, 5, "f_v");
117     PaddedBuffer<float> f_ref(26, 5, "f_res");
118 
119     stmt = FlattenIndexes(stmt);
120     SimpleIREvaluator ir_eval(stmt, {tensor});
121     ir_eval(f_v);
122 
123     for (int x = 0; x < 26; x++) {
124       for (int y = 0; y < 5; y++) {
125         f_ref(x, y) = 1 + x * x + y * y;
126       }
127     }
128 
129     ExpectAllNear(f_v, f_ref, 1e-5);
130   }
131 }
132 
getSimplifiedBody(const LoopNest & l)133 BlockPtr getSimplifiedBody(const LoopNest& l) {
134   StmtPtr stmt = l.root_stmt();
135   StmtPtr simplified = IRSimplifier::simplify(stmt);
136   return to<Block>(simplified);
137 }
138 
assertForRange(ForPtr f,int expected_start,int expected_stop)139 void assertForRange(ForPtr f, int expected_start, int expected_stop) {
140   ASSERT_NE(f, nullptr);
141   IntImmPtr start = to<IntImm>(f->start());
142   ASSERT_NE(start, nullptr);
143   ASSERT_EQ(start->value(), expected_start);
144   IntImmPtr stop = to<IntImm>(f->stop());
145   ASSERT_NE(stop, nullptr);
146   ASSERT_EQ(stop->value(), expected_stop);
147 }
148 
assertForRanges(BlockPtr body,const std::vector<std::pair<int,int>> & start_stops)149 void assertForRanges(
150     BlockPtr body,
151     const std::vector<std::pair<int, int>>& start_stops) {
152   ASSERT_EQ(body->nstmts(), start_stops.size());
153 
154   auto it = body->begin();
155   for (size_t i = 0; i < start_stops.size(); i++, it++) {
156     ForPtr loop = to<For>(*it);
157     assertForRange(loop, start_stops[i].first, start_stops[i].second);
158   }
159 }
160 
// sliceHead on a loop bound to GPU block index Y: asserts that the tail still
// carries the binding and that the head has default loop options.
TEST(LoopNest, ExprSliceHeadWithLoopOptions) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);

  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceHead(loops[0], 2, &head, &tail);

  assertForRanges(getSimplifiedBody(nest), {{0, 2}, {0, 8}});

  // Tail: keeps the GPU block binding.
  ASSERT_TRUE(tail->loop_options().is_gpu_block_index());
  ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  // Head: plain loop.
  ASSERT_TRUE(head->loop_options().isDefault());
}
183 
// sliceTail on a loop bound to GPU block index Y: asserts that the resulting
// head keeps the binding while the newly sliced tail has default options.
TEST(LoopNest, ExprSliceTailWithLoopOptions) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  // First slice: no loop options involved.
  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceTail(loops[0], 4, &head, &tail);

  // Second slice: bind the tail to a GPU block index, then slice it again.
  tail->set_gpu_block_index(LoopOptions::IDX_Y);
  ForPtr tail_head = nullptr;
  ForPtr tail_tail = nullptr;
  LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail);

  assertForRanges(getSimplifiedBody(nest), {{0, 6}, {0, 2}, {8, 10}});

  ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index());
  ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  ASSERT_TRUE(head->loop_options().isDefault());
  ASSERT_TRUE(tail_tail->loop_options().isDefault());
}
213 
// sliceHead with a factor equal to the loop's trip count: the test expects the
// original loop to be returned as the head, a null tail, and an unchanged
// [0, 10) range.
TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceHead(loops[0], 10, &head, &tail);

  ASSERT_EQ(head, loops[0]);
  ASSERT_EQ(tail, nullptr);

  assertForRanges(getSimplifiedBody(nest), {{0, 10}});
}
235 
// sliceHead with a factor (100) larger than the loop's trip count (10): the
// test expects the original loop as the head, a null tail, and an unchanged
// [0, 10) range.
TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceHead(loops[0], 100, &head, &tail);

  ASSERT_EQ(head, loops[0]);
  ASSERT_EQ(tail, nullptr);

  assertForRanges(getSimplifiedBody(nest), {{0, 10}});
}
255 
// Basic sliceHead: a new head loop [0, 4) is produced and the original loop
// handle is returned as the tail [4, 10).
TEST(LoopNest, ExprSliceHead) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceHead(loops[0], 4, &head, &tail);

  // Head is a fresh loop; tail is the original one.
  ASSERT_NE(head, nullptr);
  ASSERT_NE(head, loops[0]);
  ASSERT_NE(tail, nullptr);
  ASSERT_EQ(tail, loops[0]);

  assertForRanges(getSimplifiedBody(nest), {{0, 4}, {4, 10}});
}
277 
// sliceHead applied to a loop that does not start at zero: first sliceTail
// produces a [6, 10) loop, which is then head-sliced by 2.
TEST(LoopNest, ExprSliceHeadWithNonZeroStart) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  // After this call: head covers [0, 6), tail covers [6, 10).
  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceTail(loops[0], 4, &head, &tail);

  // Slicing the tail's head yields [6, 8) and [8, 10).
  LoopNest::sliceHead(tail, 2);

  assertForRanges(getSimplifiedBody(nest), {{0, 6}, {6, 8}, {8, 10}});
}
301 
// sliceTail with a factor equal to the loop's trip count: the test expects a
// null head, the original loop returned as the tail, and an unchanged
// [0, 10) range.
TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceTail(loops[0], 10, &head, &tail);

  ASSERT_EQ(head, nullptr);
  ASSERT_EQ(tail, loops[0]);

  assertForRanges(getSimplifiedBody(nest), {{0, 10}});
}
323 
// sliceTail with a factor (100) larger than the loop's trip count (10): the
// test expects a null head, the original loop returned as the tail, and an
// unchanged [0, 10) range.
TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceTail(loops[0], 100, &head, &tail);

  ASSERT_EQ(head, nullptr);
  ASSERT_EQ(tail, loops[0]);

  assertForRanges(getSimplifiedBody(nest), {{0, 10}});
}
345 
// Basic sliceTail: the original loop handle is returned as the head [0, 6)
// and a new tail loop [6, 10) is produced.
TEST(LoopNest, ExprSliceTail) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceTail(loops[0], 4, &head, &tail);

  // Head is the original loop; tail is a fresh one.
  ASSERT_NE(head, nullptr);
  ASSERT_EQ(head, loops[0]);
  ASSERT_NE(tail, nullptr);
  ASSERT_NE(tail, loops[0]);

  assertForRanges(getSimplifiedBody(nest), {{0, 6}, {6, 10}});
}
367 
// Combines three transformations on a 100-element compute:
//   0: splitWithTail
//   1: sliceTail on the inner loop
//   2: sliceHead on the outer loop
TEST(LoopNest, ExprSplitAndSlice) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {100}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  // After the split:
  //   outer: [0, 4)
  //   inner: [0, 21)
  //   tail:  [84, 100)
  ForPtr inner = nullptr;
  ForPtr tail = nullptr;
  LoopNest::splitWithTail(loops[0], 21, &inner, &tail);
  LoopNest::sliceTail(inner, 2);
  LoopNest::sliceHead(loops[0], 2);

  // Expected structure:
  // for (int x_outer = 0; x_outer < 2; x_outer++) {
  //   for (int x_inner = 0; x_inner < 19; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  //   for (int x_inner = 19; x_inner < 21; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  // }
  // for (int x_outer = 2; x_outer < 4; x_outer++) {
  //   for (int x_inner = 0; x_inner < 19; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  //   for (int x_inner = 19; x_inner < 21; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  // }
  // for (int x_tail = 0; x_tail < 16; x_tail++) {
  //   f[x_tail + 84] = 1.f + float(x_tail + 84);
  // }
  BlockPtr body = getSimplifiedBody(nest);
  assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}});

  // Both outer slices must contain the same pair of inner loops.
  auto it = body->begin();
  for (int slice_idx = 0; slice_idx < 2; ++slice_idx, ++it) {
    ForPtr outer_slice = to<For>(*it);
    assertForRanges(outer_slice->body(), {{0, 19}, {19, 21}});
  }
}
420 
// Applies sliceHead and then normalizes the resulting tail loop:
//   0: sliceHead
//   1: normalize tail
TEST(LoopNest, ExprSliceAndNormalize) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  // After slicing: head covers [0, 2), tail covers [2, 10).
  ForPtr head = nullptr;
  ForPtr tail = nullptr;
  LoopNest::sliceHead(loops[0], 2, &head, &tail);

  // After normalization the tail is expected to cover [0, 8).
  LoopNest::normalize(tail);

  assertForRanges(getSimplifiedBody(nest), {{0, 2}, {0, 8}});
}
445 
446 template <typename T>
evalExpr(const ExprHandle & expr,const VarHandle & var,T value)447 T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) {
448   ExprEval<SimpleIREvaluator> eval(expr, {var});
449   return eval.value<T>(value);
450 }
451 
// Slices a loop whose extent is the symbolic variable `dim` (head by 2, then
// the remaining tail by 2) and, for several concrete values of `dim`,
// evaluates the bounds of the three resulting loops against expectations.
TEST(LoopNest, ExprSliceWithVariableDimension) {
  using Ranges = std::vector<std::pair<int, int>>;

  auto checkSlicesForDim = [](int dimension,
                              const Ranges& expected_for_ranges) {
    VarHandle dim("dim", kInt);
    Tensor tensor = Compute("f", {dim}, [](const ExprHandle& x) { return x; });
    LoopNest nest({tensor});
    std::vector<ForPtr> loops =
        nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

    ForPtr head = nullptr;
    ForPtr tail = nullptr;
    LoopNest::sliceHead(loops[0], 2, &head, &tail);
    LoopNest::sliceTail(tail, 2);

    BlockPtr body = getSimplifiedBody(nest);
    ASSERT_EQ(expected_for_ranges.size(), 3);

    // Bounds are symbolic, so evaluate them with `dim` set to `dimension`.
    auto it = body->begin();
    for (auto& expected : expected_for_ranges) {
      ForPtr loop = to<For>(*it++);
      int actual_start =
          evalExpr<int>(ExprHandle(loop->start()), dim, dimension);
      int actual_stop = evalExpr<int>(ExprHandle(loop->stop()), dim, dimension);
      ASSERT_EQ(actual_start, expected.first);
      ASSERT_EQ(actual_stop, expected.second);
    }
  };

  checkSlicesForDim(1, {{0, 1}, {1, 1}, {1, 1}});
  checkSlicesForDim(2, {{0, 2}, {2, 2}, {2, 2}});
  checkSlicesForDim(3, {{0, 2}, {2, 2}, {2, 3}});
  checkSlicesForDim(4, {{0, 2}, {2, 2}, {2, 4}});
  checkSlicesForDim(5, {{0, 2}, {2, 3}, {3, 5}});
  checkSlicesForDim(10, {{0, 2}, {2, 8}, {8, 10}});
}
490 
// Splits a 199-iteration loop with a tail by 17 and then by 7, and checks the
// ranges of the three top-level loops left after simplification.
TEST(LoopNest, ExprSplitWithTail) {
  auto compute_fn = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {199}, compute_fn);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  LoopNest::splitWithTail(loops[0], 17);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  LoopNest::splitWithTail(loops[0], 7);

  BlockPtr body = getSimplifiedBody(nest);
  ASSERT_EQ(body->nstmts(), 3);

  // Verify that the split loops appear in the expected order.
  const std::vector<std::pair<int, int>> expected_ranges = {
      {0, 7}, {0, 4}, {0, 12}};
  auto it = body->begin();
  for (const auto& expected : expected_ranges) {
    ForPtr loop = to<For>(*it++);
    assertForRange(loop, expected.first, expected.second);
  }
}
519 
// splitWithTail where the factor (4) divides the extent (24): the printed IR
// is compared against a hand-built reference that has no tail loop, and the
// evaluated output is compared against values computed in C++.
TEST(LoopNest, ExprSplitWithTailNone) {
  auto func = [](const ExprHandle& x, const ExprHandle& y) {
    return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
  };
  Tensor tensor = Compute("f", {24, 5}, func);
  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::splitWithTail(loops[0], 4);

  StmtPtr stmt = nest.root_stmt();
  std::ostringstream printed;
  printed << *stmt;
  const std::string actual_ir = printed.str();
  ASSERT_GT(actual_ir.size(), 200);
  ASSERT_LT(actual_ir.size(), 600);

  {
    // Build the expected loop structure by hand and compare printed forms.
    VarHandle x_outer("i_outer", kInt);
    VarHandle x_inner("i_inner", kInt);
    VarHandle y("i", kInt);
    VarHandle x_tail("i_tail", kInt);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
    BufHandle f("f", {24, 5}, kFloat);
    ExprHandle main_index = x_outer * 4 + x_inner;
    ExprHandle outer_stop = (ExprHandle(24) - 0) / 4;

    ForPtr y_loop = For::make(
        y, 0, 5, Store::make(f, {main_index, y}, func(main_index, y)));
    ForPtr inner_loop = For::make(x_inner, 0, 4, y_loop);
    ForPtr outer_loop = For::make(x_outer, 0, outer_stop, inner_loop);
    StmtPtr reference = alloc<Block>(std::vector<StmtPtr>({outer_loop}));

    std::ostringstream expected;
    expected << *reference;
    ASSERT_EQ(actual_ir, expected.str());
  }

  {
    // Reference values computed on the host.
    PaddedBuffer<float> f_v(24, 5, "f_v");
    PaddedBuffer<float> f_ref(24, 5, "f_res");
    for (int row = 0; row < 24; ++row) {
      for (int col = 0; col < 5; ++col) {
        f_ref(row, col) = 1 + row * row + col * col;
      }
    }

    SimpleIREvaluator ir_eval(stmt, {tensor});
    ir_eval(f_v);

    ExpectAllNear(f_v, f_ref, 1e-5);
  }
}
576 
// Applies splitWithMask (factor 4) to the inner loop of a 26x5 element-wise
// add and checks the evaluated result against a host-computed reference.
TEST(LoopNest, ExprSplitWithMask01) {
  const int M = 26;
  const int N = 5;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {M, N}, kFloat);
  auto compute_fn = [&](const ExprHandle& m, const ExprHandle& n) {
    return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f;
  };
  Tensor tensor = Compute("f", {M, N}, compute_fn);

  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::splitWithMask(loops[1], 4);

  StmtPtr stmt = nest.root_stmt();

  // Inputs and expected output.
  PaddedBuffer<float> a_v(M, N, "a");
  PaddedBuffer<float> b_v(M, N, "b");
  PaddedBuffer<float> c_v(M, N, "c");
  PaddedBuffer<float> c_ref(M, N, "c_ref");
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      a_v(row, col) = 2 * row;
      b_v(row, col) = 3 * col;
      c_ref(row, col) = a_v(row, col) + b_v(row, col) + 1.0f;
    }
  }

  SimpleIREvaluator evaluator(stmt, {a_buf, b_buf, tensor});
  evaluator(a_v, b_v, c_v);

  ExpectAllNear(c_v, c_ref, 1e-5);
}
609 
610 // Tests the case where we split a loop cleanly multiple times, we should not
611 // insert any masks.
TEST(LoopNest,ExprSplitWithMaskRepeatedNoMask)612 TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) {
613   const int M = 64;
614   BufHandle a_buf("a", {M}, kFloat);
615   BufHandle b_buf("b", {M}, kFloat);
616   Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
617     return a_buf.load(m) + b_buf.load(m) + 1.0f;
618   });
619 
620   LoopNest l({tensor});
621   std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
622   LoopNest::splitWithMask(loops[0], 4);
623   LoopNest::splitWithMask(loops[0], 4);
624 
625   StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt());
626 
627   // Two splits mean 3 loops, but should need no masks in this case.
628   checkIR(stmt1, R"IR(
629 # CHECK: for (
630 # CHECK-NOT: if (
631 # CHECK:   for (
632 # CHECK-NOT: if (
633 # CHECK:     for (
634 # CHECK-NOT: if (
635 # CHECK:       f[)IR");
636 }
637 
// Checks that LoopNest::getLoopAt(for_i, {0, 2}) returns the `k2` loop of the
// hand-built nest below, both by pointer identity and by printed form.
TEST(LoopNest, getLoopAt) {
  // Input IR:
  //  for (int i = 0; i < 100; i++) {
  //    for (int j = 0; j < 100; j++) {
  //      A[i, j] = sin(i * j);
  //      for (int k1 = 0; k1 < 200; k1++) {
  //        B[i, j, k1] = (A[i, j]) / (k1 + 1);
  //      }
  //      for (int k2 = 0; k2 < 300; k2++) {
  //        C[i, j, k2] = (A[i, j]) * (k2 + 1);
  //      }
  //    }
  //  }
  BufPtr A = alloc<Buf>(
      "A",
      std::vector<ExprPtr>({alloc<IntImm>(100), alloc<IntImm>(100)}),
      kInt);
  BufPtr B = alloc<Buf>(
      "B",
      std::vector<ExprPtr>(
          {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(200)}),
      kInt);
  BufPtr C = alloc<Buf>(
      "C",
      std::vector<ExprPtr>(
          {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(300)}),
      kInt);
  BufHandle a_buf(A);
  BufHandle b_buf(B);
  BufHandle c_buf(C);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k1("k1", kInt);
  VarHandle k2("k2", kInt);
  auto store1 = Store::make(a_buf, {i, j}, sin(i * j));
  auto store2 = Store::make(
      b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1)));
  auto store3 = Store::make(
      c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1)));
  auto for_k2 = For::make(k2, 0, 300, Block::make({store3}));
  auto for_k1 = For::make(k1, 0, 200, Block::make({store2}));
  auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2}));
  auto for_i = For::make(i, 0, 100, for_j);
  LoopNest l(Block::make({for_i}), {B, C});

  // Index path {0, 2}: statement 0 of for_i's body (for_j), then statement 2
  // of for_j's body (for_k2).
  auto ret_k2 = l.getLoopAt(for_i, {0, 2});
  // Use gtest assertions (as the rest of this file does) so a mismatch is
  // reported as a regular test failure before ret_k2 is dereferenced.
  ASSERT_NE(ret_k2, nullptr);
  ASSERT_EQ(ret_k2, for_k2);

  std::ostringstream oss;
  oss << *ret_k2;
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int k2
# CHECK-NEXT: C[i, j, k2] =
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
694 
// Tiles a 64x64 element-wise add by 4x8. Both factors divide the extents, so
// the IR check requires that no tail loops appear; the evaluated result is
// then compared against a host-computed reference.
TEST(LoopNest, TileSimple) {
  const int M = 64;
  const int N = 64;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {M, N}, kFloat);
  auto compute_fn = [&](const ExprHandle& m, const ExprHandle& n) {
    return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f;
  };
  Tensor tensor = Compute("f", {M, N}, compute_fn);

  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  nest.tile(loops[0], loops[1], 4, 8);

  // Structural check on the simplified IR.
  StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt());
  const std::string tiled_pattern = R"IR(
# CHECK: for (int i_outer
# CHECK:   for (int i_outer_1
# CHECK:     for (int i_inner
# CHECK:       for (int i_inner_1
# CHECK:         f[
# CHECK-NOT:     for (int i_tail
# CHECK-NOT: for (int i_tail)IR";
  checkIR(stmt, tiled_pattern);

  // Numerical check.
  PaddedBuffer<float> a_v(M, N, "a");
  PaddedBuffer<float> b_v(M, N, "b");
  PaddedBuffer<float> c_v(M, N, "c");
  PaddedBuffer<float> c_ref(M, N, "c_ref");
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      a_v(row, col) = 2 * row;
      b_v(row, col) = 3 * col;
      c_ref(row, col) = a_v(row, col) + b_v(row, col) + 1.0f;
    }
  }

  SimpleIREvaluator evaluator(stmt, {a_buf, b_buf, tensor});
  evaluator(a_v, b_v, c_v);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(c_v, c_ref, 1e-5);
}
739 
// Tiles a 64x64 element-wise add by 5x9. Neither factor divides the extents,
// so the IR check expects tail loops; the evaluated result is then compared
// against a host-computed reference.
TEST(LoopNest, TileWithTails) {
  const int M = 64;
  const int N = 64;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {M, N}, kFloat);
  auto compute_fn = [&](const ExprHandle& m, const ExprHandle& n) {
    return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f;
  };
  Tensor tensor = Compute("f", {M, N}, compute_fn);

  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  nest.tile(loops[0], loops[1], 5, 9);

  // Structural check on the simplified IR.
  StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt());
  const std::string tiled_pattern = R"IR(
# CHECK: for (int i_outer
# CHECK:   for (int i_outer_1
# CHECK:     for (int i_inner
# CHECK:       for (int i_inner_1
# CHECK:         f[
# CHECK:   for (int i_inner
# CHECK:     f[
# CHECK: for (int i_tail)IR";
  checkIR(stmt, tiled_pattern);

  // Numerical check.
  PaddedBuffer<float> a_v(M, N, "a");
  PaddedBuffer<float> b_v(M, N, "b");
  PaddedBuffer<float> c_v(M, N, "c");
  PaddedBuffer<float> c_ref(M, N, "c_ref");
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      a_v(row, col) = 2 * row;
      b_v(row, col) = 3 * col;
      c_ref(row, col) = a_v(row, col) + b_v(row, col) + 1.0f;
    }
  }

  SimpleIREvaluator evaluator(stmt, {a_buf, b_buf, tensor});
  evaluator(a_v, b_v, c_v);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(c_v, c_ref, 1e-5);
}
785 
// Tiles the two middle loops (indices 1 and 2) of a 4-deep 8x8x8x8 nest by
// 3x3, checks the loop structure of the simplified IR, and compares the
// evaluated result against a host-computed reference.
TEST(LoopNest, TileInMiddle) {
  const int M = 8;
  const int N = 8;
  const int L = 8;
  const int K = 8;
  BufHandle a_buf("a", {M, N, L, K}, kFloat);
  BufHandle b_buf("b", {M, N, L, K}, kFloat);
  auto compute_fn = [&](const ExprHandle& m,
                        const ExprHandle& n,
                        const ExprHandle& l,
                        const ExprHandle& k) {
    return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f;
  };
  Tensor tensor = Compute("f", {M, N, L, K}, compute_fn);

  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  nest.tile(loops[1], loops[2], 3, 3);

  // Structural check on the simplified IR.
  StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt());
  const std::string tiled_pattern = R"IR(
# CHECK: for (int i
# CHECK:   for (int i_outer
# CHECK:     for (int i_outer_1
# CHECK:       for (int i_inner
# CHECK:         for (int i_inner_1
# CHECK:           for (int i_1
# CHECK:             f[
# CHECK:     for (int i_tail_1
# CHECK:       for (int i_inner_1
# CHECK:         for (int i_1
# CHECK:           f[
# CHECK:   for (int i_tail)IR";
  checkIR(stmt, tiled_pattern);

  // Numerical check.
  PaddedBuffer<float> a_v(M, N, L, K, "a");
  PaddedBuffer<float> b_v(M, N, L, K, "b");
  PaddedBuffer<float> c_v(M, N, L, K, "c");
  PaddedBuffer<float> c_ref(M, N, L, K, "c_ref");
  for (int im = 0; im < M; ++im) {
    for (int in = 0; in < N; ++in) {
      for (int il = 0; il < L; ++il) {
        for (int ik = 0; ik < K; ++ik) {
          a_v(im, in, il, ik) = 2 * (im + il);
          b_v(im, in, il, ik) = 3 * (in + ik);
          c_ref(im, in, il, ik) =
              a_v(im, in, il, ik) + b_v(im, in, il, ik) + 1.0f;
        }
      }
    }
  }

  SimpleIREvaluator evaluator(stmt, {a_buf, b_buf, tensor});
  evaluator(a_v, b_v, c_v);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(c_v, c_ref, 1e-5);
}
845 
// splitWithTail on a loop bound to GPU block index Y: asserts that the outer
// loop keeps the binding while the inner and tail loops have default options.
TEST(LoopNest, SplitWithTailWithLoopOptions) {
  const int M = 21;
  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M}, kFloat);
  auto compute_fn = [&](const ExprHandle& m) {
    return a_buf.load(m) + b_buf.load(m) + 1.0f;
  };
  Tensor tensor = Compute("f", {M}, compute_fn);

  LoopNest nest({tensor});
  auto loops = NodeFinder<For>::find(nest.root_stmt());
  ASSERT_GT(loops.size(), 0);
  loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);

  ForPtr inner = nullptr;
  ForPtr tail = nullptr;
  LoopNest::splitWithTail(loops[0], 4, &inner, &tail);
  ASSERT_NE(inner, nullptr);
  ASSERT_NE(tail, nullptr);

  // The original loop handle now refers to the outer loop, which is expected
  // to carry the axis binding.
  ForPtr outer = loops[0];
  ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
  ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  // Neither the inner loop nor the tail loop has any binding.
  ASSERT_TRUE(inner->loop_options().isDefault());
  ASSERT_TRUE(tail->loop_options().isDefault());
}
875 
// splitWithMask on a loop bound to GPU block index Y: asserts that the outer
// loop keeps the binding while the inner loop has default options.
TEST(LoopNest, SplitWithMaskWithLoopOptions) {
  const int M = 21;
  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M}, kFloat);
  Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
    return a_buf.load(m) + b_buf.load(m) + 1.0f;
  });
  ForPtr inner = nullptr;

  LoopNest l({tensor});
  auto loops = NodeFinder<For>::find(l.root_stmt());
  // Guard the loops[0] access below, mirroring SplitWithTailWithLoopOptions.
  ASSERT_GT(loops.size(), 0);
  loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
  LoopNest::splitWithMask(loops[0], 4, &inner);
  // Fail cleanly instead of dereferencing a null pointer if the split did not
  // hand back an inner loop.
  ASSERT_NE(inner, nullptr);
  ForPtr outer = loops[0];

  // Outer loop carries loop axis bindings.
  ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
  ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  // Inner loop has none.
  ASSERT_TRUE(inner->loop_options().isDefault());
}
899 
TEST(LoopNest, ScheduleBroadcastAddBuffer) {
  // Evaluates broadcast_add[m, n, k] = a[m, n] + b[n, k] and compares the
  // result against a reference computed on the host.
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);
  Tensor c = Compute(
      "broadcast_add",
      {M, N, K},
      [&](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return a_buf.load(mi, ni) + b_buf.load(ni, ki);
      });
  LoopNest nest({c});
  StmtPtr root = nest.root_stmt();

  // Fill the inputs and snapshot them so they can be checked after the run.
  PaddedBuffer<float> a_v(M, N, "a_v");
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      a_v(row, col) = 7 * row * col;
    }
  }
  a_v.Backup();

  PaddedBuffer<float> b_v(N, K, "b_v");
  for (int row = 0; row < N; ++row) {
    for (int col = 0; col < K; ++col) {
      b_v(row, col) = 11 * row * col;
    }
  }
  b_v.Backup();

  PaddedBuffer<float> c_v(M, N, K, "c_buf");
  SimpleIREvaluator evaluator(root, {a_buf, b_buf, c});
  evaluator(a_v, b_v, c_v);

  a_v.CheckBackup();
  b_v.CheckBackup();

  // Host-side reference built from the same formulas used to fill a_v and b_v.
  PaddedBuffer<float> c_ref(M, N, K, "c_ref");
  for (int mi = 0; mi < M; ++mi) {
    for (int ni = 0; ni < N; ++ni) {
      for (int ki = 0; ki < K; ++ki) {
        c_ref(mi, ni, ki) = 7 * mi * ni + 11 * ni * ki;
      }
    }
  }
  ExpectAllNear(c_v, c_ref, 1e-5);
}
947 
TEST(LoopNest, ScheduleFunctionCall01) {
  // d[m, n, k] = c[m, n, k] + 1 where c is an intermediate tensor
  // (c[m, n, k] = a[m, n] + b[n, k]); only d is passed as an output.
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);
  Tensor c = Compute(
      "broadcast_add",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) + b_buf.load(n, k);
      });
  Tensor d = Compute(
      "d",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return c.load(m, n, k) + 1;
      });

  LoopNest l({d}, {c, d});
  l.prepareForCodegen();
  StmtPtr stmt = l.root_stmt();

  // Sanity check that a non-trivial statement was generated.
  std::ostringstream oss;
  oss << *stmt;
  ASSERT_GT(oss.str().size(), 100);

  // No host buffer is needed for c: it is not among the evaluator arguments.
  PaddedBuffer<float> a_v(M, N);
  PaddedBuffer<float> b_v(N, K);
  PaddedBuffer<float> d_v(M, N, K);
  PaddedBuffer<float> d_ref(M, N, K);

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      a_v(i, j) = i * i;
    }
  }
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < K; j++) {
      b_v(i, j) = j * j;
    }
  }
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      for (int k = 0; k < K; k++) {
        d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1;
      }
    }
  }

  SimpleIREvaluator eval(stmt, {a_buf, b_buf, d});
  eval(a_v, b_v, d_v);

  ExpectAllNear(d_v, d_ref, 1e-5);
}
1003 
TEST(LoopNest,ScheduleInlineSimple)1004 TEST(LoopNest, ScheduleInlineSimple) {
1005   const int M = 4;
1006   const int N = 5;
1007   const int K = 6;
1008   BufHandle a_buf("a", {M, N}, kFloat);
1009   BufHandle b_buf("b", {N, K}, kFloat);
1010   BufHandle c_buf("c", {M, N}, kFloat);
1011   BufHandle d_buf("d", {M, K}, kFloat);
1012 
1013   Tensor x = Compute(
1014       "x",
1015       {M, N, K},
1016       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1017         return a_buf.load(m, n) * b_buf.load(n, k);
1018       });
1019   Tensor y = Compute(
1020       "y",
1021       {M, N, K},
1022       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1023         return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
1024       });
1025 
1026   LoopNest l1({y}, {x, y});
1027   LoopNest l2(l1);
1028   l2.computeInline(x.buf());
1029 
1030   l1.prepareForCodegen();
1031   l2.prepareForCodegen();
1032 
1033   StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1034   StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());
1035 
1036   SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y});
1037   SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y});
1038 
1039   PaddedBuffer<float> a_v(M, N);
1040   PaddedBuffer<float> b_v(N, K);
1041   PaddedBuffer<float> c_v(M, N);
1042   PaddedBuffer<float> d_v(M, K);
1043 
1044   for (int i = 0; i < M; i++) {
1045     for (int j = 0; j < N; j++) {
1046       a_v(i, j) = i * i;
1047     }
1048   }
1049   for (int i = 0; i < N; i++) {
1050     for (int j = 0; j < K; j++) {
1051       b_v(i, j) = j * j;
1052     }
1053   }
1054   for (int i = 0; i < M; i++) {
1055     for (int j = 0; j < N; j++) {
1056       c_v(i, j) = i + j;
1057     }
1058   }
1059   for (int i = 0; i < M; i++) {
1060     for (int j = 0; j < K; j++) {
1061       d_v(i, j) = i * j;
1062     }
1063   }
1064 
1065   PaddedBuffer<float> y_1(M, N, K);
1066   PaddedBuffer<float> y_2(M, N, K);
1067 
1068   eval1(a_v, b_v, c_v, d_v, y_1);
1069   eval2(a_v, b_v, c_v, d_v, y_2);
1070   ExpectAllNear(y_1, y_2, 1e-5);
1071   std::ostringstream oss1, oss2;
1072   oss1 << *stmt1;
1073   oss2 << *stmt2;
1074   ASSERT_GT(oss1.str().size(), oss2.str().size());
1075 }
1076 
// Returns a copy of `str` with every character for which isspace() reports
// true removed. The input string is left untouched.
static std::string remove_space(const std::string& str) {
  std::string str_new = str;
  str_new.erase(
      std::remove_if(
          str_new.begin(),
          str_new.end(),
          // isspace() is only defined for EOF and values representable as
          // unsigned char; taking the byte as unsigned char keeps bytes
          // >= 0x80 from being passed as negative ints where char is signed.
          [](unsigned char ch) { return isspace(ch) != 0; }),
      str_new.end());
  return str_new;
}
1083 
InlineFunc01Helper(const std::vector<std::string> & inline_order)1084 void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
1085   const int M = 4;
1086   const int N = 5;
1087   const int K = 6;
1088   BufHandle a_buf("a", {M, N}, kFloat);
1089   BufHandle b_buf("b", {N, K}, kFloat);
1090   BufHandle c_buf("c", {M, N}, kFloat);
1091   BufHandle d_buf("d", {M, K}, kFloat);
1092 
1093   Tensor x = Compute(
1094       "x",
1095       {M, N, K},
1096       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1097         return a_buf.load(m, n) * b_buf.load(n, k);
1098       });
1099   Tensor y = Compute(
1100       "y",
1101       {M, N, K},
1102       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1103         return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
1104       });
1105   Tensor z = Compute(
1106       "z",
1107       {M, N, K},
1108       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1109         return x.load(m, n, k) + y.load(m, n, k);
1110       });
1111 
1112   LoopNest l({z}, {x, y, z});
1113   for (const std::string& order : inline_order) {
1114     if (order == "x") {
1115       l.computeInline(x.buf());
1116     } else if (order == "y") {
1117       l.computeInline(y.buf());
1118     } else {
1119       throw std::runtime_error("Invalid order: " + order);
1120     }
1121   }
1122   l.prepareForCodegen();
1123   StmtPtr stmt = l.root_stmt();
1124 
1125   std::ostringstream oss;
1126   oss << *stmt;
1127   std::string str1 = remove_space(oss.str());
1128 
1129   {
1130     PaddedBuffer<float> a_v(M, N);
1131     PaddedBuffer<float> b_v(N, K);
1132     PaddedBuffer<float> c_v(M, N);
1133     PaddedBuffer<float> d_v(M, K);
1134 
1135     for (int i = 0; i < M; i++) {
1136       for (int j = 0; j < N; j++) {
1137         a_v(i, j) = i * i;
1138       }
1139     }
1140     for (int i = 0; i < N; i++) {
1141       for (int j = 0; j < K; j++) {
1142         b_v(i, j) = j * j;
1143       }
1144     }
1145     for (int i = 0; i < M; i++) {
1146       for (int j = 0; j < N; j++) {
1147         c_v(i, j) = i + j;
1148       }
1149     }
1150     for (int i = 0; i < M; i++) {
1151       for (int j = 0; j < K; j++) {
1152         d_v(i, j) = i * j;
1153       }
1154     }
1155 
1156     PaddedBuffer<float> z_v(M, N, K);
1157     PaddedBuffer<float> z_ref(M, N, K);
1158     for (int m = 0; m < M; m++) {
1159       for (int n = 0; n < N; n++) {
1160         for (int k = 0; k < K; k++) {
1161           z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
1162         }
1163       }
1164     }
1165 
1166     SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z});
1167     eval(a_v, b_v, c_v, d_v, z_v);
1168     ExpectAllNear(z_v, z_ref, 1e-5);
1169   }
1170 
1171   if (inline_order.size() == 2) {
1172     Tensor z2 = Compute(
1173         "z",
1174         {M, N, K},
1175         [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1176           return a_buf.load(m, n) * b_buf.load(n, k) +
1177               (c_buf.load(m, n) * d_buf.load(m, k) +
1178                a_buf.load(m, n) * b_buf.load(n, k));
1179         });
1180     LoopNest l2({z2});
1181     l2.prepareForCodegen();
1182     StmtPtr stmt2 = l2.root_stmt();
1183 
1184     std::ostringstream oss2;
1185     oss2 << *stmt2;
1186     std::string str2 = remove_space(oss2.str());
1187 
1188     ASSERT_EQ(str1, str2);
1189     ASSERT_GT(str1.size(), 100);
1190   }
1191 }
1192 
TEST(LoopNest, ScheduleInlineFunc01) {
  // Runs the helper for both full inlining orders, each single tensor, and no
  // inlining at all, in that sequence.
  const std::vector<std::vector<std::string>> inline_orders = {
      {"x", "y"},
      {"y", "x"},
      {"x"},
      {"y"},
      {},
  };
  for (const auto& order : inline_orders) {
    InlineFunc01Helper(order);
  }
}
1200 
1201 // Make sure we cache random vars if we should.
TEST(LoopNest, ScheduleInlineRandom) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  // x is a random value modulo 5; y reads the same element of x twice.
  Tensor x = Compute(
      "x",
      {M, N, K},
      [](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return Mod::make(Intrinsics::make(kRand, kInt), 5);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&x](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return x.load(mi, ni, ki) + x.load(mi, ni, ki);
      });

  LoopNest nest({y}, {x, y});
  nest.computeInline(x.buf());

  // Results are not evaluated here (the original note says Rand is not
  // implemented in SimpleIREvaluator); only the printed IR is inspected.
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  // Both loads must share a single rand() value bound to a local.
  checkIR(simplified, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK:   for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK:     for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK:       int x = rand();
# CHECK:       y[i, i_1, i_2] = 2 * (x % 5);)IR");
}
1235 
1236 // Make sure we don't cache random vars that are not being inlined.
TEST(LoopNest, ScheduleInlineRandomUnrelated) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  // x has no randomness; the two rand() calls live in the consumer y.
  Tensor x = Compute(
      "x",
      {M, N, K},
      [](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return mi * ni * ki;
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&x](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return x.load(mi, ni, ki) + Intrinsics::make(kRand, kInt) +
            Intrinsics::make(kRand, kInt);
      });

  LoopNest nest({y}, {x, y});
  nest.computeInline(x.buf());

  // Results are not evaluated here (the original note says Rand is not
  // implemented in SimpleIREvaluator); only the printed IR is inspected.
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  // The two rand() calls in y must remain separate calls.
  checkIR(simplified, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK:   for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK:     for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK:       y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR");
}
1270 
1271 // Make sure we generate the right number of random values == the dimensionality
1272 // of the production tensor.
TEST(LoopNest, ScheduleInlineRandomLowerDimensions) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  // x is one-dimensional ({M}); y is three-dimensional and reads x[m] twice.
  Tensor x = Compute("x", {M}, [](const VarHandle& mi) {
    return Mod::make(Intrinsics::make(kRand, kInt), 5);
  });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&x](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return x.load(mi) + x.load(mi);
      });

  LoopNest nest({y}, {x, y});
  nest.computeInline(x.buf());

  // Results are not evaluated here (the original note says Rand is not
  // implemented in SimpleIREvaluator); only the printed IR is inspected.
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  // rand() is expected once per iteration of the outermost loop only.
  checkIR(simplified, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK:   int x = rand();
# CHECK:   for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK:     for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK:       y[i, i_1, i_2] = 2 * (x % 5);)IR");
}
1303 
1304 // Make sure we don't screw up intrinsics thinking they're rand.
TEST(LoopNest, ScheduleInlineIntrinsics) {
  // y = sqrt(x) with x = a * b; compares the loop nest with x materialized
  // against a copy where x is inlined into the sqrt call.
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return a_buf.load(mi, ni) * b_buf.load(ni, ki);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return Intrinsics::make(kSqrt, x.load(mi, ni, ki));
      });

  PaddedBuffer<float> a_v(M, N);
  PaddedBuffer<float> b_v(N, K);

  for (int r = 0; r < M; ++r) {
    for (int s = 0; s < N; ++s) {
      a_v(r, s) = r * r;
    }
  }
  for (int r = 0; r < N; ++r) {
    for (int s = 0; s < K; ++s) {
      b_v(r, s) = s * s;
    }
  }

  LoopNest plain({y}, {x, y});
  LoopNest inlined(plain);
  inlined.computeInline(x.buf());

  plain.prepareForCodegen();
  inlined.prepareForCodegen();

  StmtPtr plain_stmt = IRSimplifier::simplify(plain.root_stmt());
  StmtPtr inlined_stmt = IRSimplifier::simplify(inlined.root_stmt());

  SimpleIREvaluator plain_eval(plain_stmt, {a_buf, b_buf, y});
  SimpleIREvaluator inlined_eval(inlined_stmt, {a_buf, b_buf, y});

  PaddedBuffer<float> y_plain(M, N, K);
  PaddedBuffer<float> y_inlined(M, N, K);

  plain_eval(a_v, b_v, y_plain);
  inlined_eval(a_v, b_v, y_inlined);
  ExpectAllNear(y_plain, y_inlined, 1e-5);

  // The inlined statement is expected to print shorter than the plain one.
  std::ostringstream plain_text;
  std::ostringstream inlined_text;
  plain_text << *plain_stmt;
  inlined_text << *inlined_stmt;
  ASSERT_GT(plain_text.str().size(), inlined_text.str().size());
}
1363 
1364 // Make sure we can handle rand and non-rand intrinsics.
TEST(LoopNest, ScheduleInlineRandWithIntrinsics) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  // x is a float rand(); y applies sqrt to the loaded value.
  Tensor x = Compute(
      "x",
      {M, N, K},
      [](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return Intrinsics::make(kRand, kFloat);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&x](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return Intrinsics::make(kSqrt, x.load(mi, ni, ki));
      });

  LoopNest nest({y}, {x, y});
  nest.computeInline(x.buf());

  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  // rand() is bound to a local and sqrt consumes that local.
  checkIR(simplified, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK:   for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK:     for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK:       float x = rand();
# CHECK:       y[i, i_1, i_2] = sqrt(x);)IR");
}
1396 
1397 // Split a Compute then inline it into another compute.
TEST(LoopNest, ScheduleSplitAThenInline) {
  // a[i] = i * i, b[j] = a[j + 8].
  Tensor a = Compute("a", {18}, [](const VarHandle& idx) { return idx * idx; });
  Tensor b = Compute("b", {2}, [&a](const VarHandle& idx) {
    return a.load(idx + ExprHandle(8));
  });

  LoopNest nest({b}, {a, b});
  std::vector<ForPtr> a_loops =
      nest.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  LoopNest::splitWithMask(a_loops[0], 4);

  // The test expects inlining of the already-split producer to be refused.
  ASSERT_FALSE(nest.computeInline(a.buf()));
}
1408 
1409 // Split a Compute then inline another Compute into it.
TEST(LoopNest, ScheduleSplitBThenInline) {
  // a[i] = i * i, b[j] = a[j + 8].
  Tensor a = Compute("a", {18}, [](const VarHandle& idx) { return idx * idx; });
  Tensor b = Compute("b", {6}, [&a](const VarHandle& idx) {
    return a.load(idx + ExprHandle(8));
  });

  LoopNest nest({b}, {a, b});

  // Split the consumer's loop first, then inline the producer into it.
  std::vector<ForPtr> b_loops =
      nest.getAllLoopNestsWritingToBuf(b.buf()).at(0);
  LoopNest::splitWithMask(b_loops[0], 3);
  nest.computeInline(a.buf());
  nest.prepareForCodegen();
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  std::vector<int> output(6, 0);
  SimpleIREvaluator evaluator(simplified, {b});
  evaluator(output);

  for (int pos = 0; pos < 6; ++pos) {
    ASSERT_EQ(output[pos], (pos + 8) * (pos + 8));
  }
}
1430 
1431 // Split a Compute twice then inline it.
TEST(LoopNest, ScheduleSplitTwiceThenInline) {
  // a[i] = i * i, b[j] = a[j + 8].
  Tensor a = Compute("a", {18}, [](const VarHandle& idx) { return idx * idx; });
  Tensor b = Compute("b", {2}, [&a](const VarHandle& idx) {
    return a.load(idx + ExprHandle(8));
  });

  LoopNest nest({b}, {a, b});
  std::vector<ForPtr> a_loops =
      nest.getAllLoopNestsWritingToBuf(a.buf()).at(0);

  // Split the producer loop by 4, then split the resulting inner loop by 2.
  ForPtr i_inner = nullptr;
  LoopNest::splitWithMask(a_loops[0], 4, &i_inner);
  LoopNest::splitWithMask(i_inner, 2);

  // The test expects inlining of the twice-split producer to be refused.
  ASSERT_FALSE(nest.computeInline(a.buf()));
}
1445 
1446 // Inline a Compute, then split.
TEST(LoopNest, ScheduleInlineThenSplit) {
  // a[i] = i * i, b[j] = a[j + 8].
  Tensor a = Compute("a", {18}, [](const VarHandle& idx) { return idx * idx; });
  Tensor b = Compute("b", {6}, [&a](const VarHandle& idx) {
    return a.load(idx + ExprHandle(8));
  });

  LoopNest nest({b}, {a, b});
  nest.computeInline(a.buf());

  // Split the last loop found in the nest after inlining.
  std::vector<ForPtr> remaining_loops = NodeFinder<For>::find(nest.root_stmt());
  LoopNest::splitWithMask(remaining_loops.back(), 3);
  nest.prepareForCodegen();
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  std::vector<int> output(6, 0);
  SimpleIREvaluator evaluator(simplified, {b});
  evaluator(output);

  for (int pos = 0; pos < 6; ++pos) {
    ASSERT_EQ(output[pos], (pos + 8) * (pos + 8));
  }
}
1467 
1468 // Split a Compute, inline it, then split the result.
TEST(LoopNest, ScheduleSplitInlineThenSplit) {
  // a[i] = i * i, b[j] = a[j + 8].
  Tensor a = Compute("a", {18}, [](const VarHandle& idx) { return idx * idx; });
  Tensor b = Compute("b", {16}, [&a](const VarHandle& idx) {
    return a.load(idx + ExprHandle(8));
  });

  LoopNest nest({b}, {a, b});

  // First split: the last loop found before inlining.
  auto loops_before = NodeFinder<For>::find(nest.root_stmt());
  LoopNest::splitWithMask(loops_before.back(), 2);
  nest.computeInline(a.buf());

  // Second split: the first loop found after inlining.
  auto loops_after = NodeFinder<For>::find(nest.root_stmt());
  LoopNest::splitWithMask(loops_after.front(), 2);
  nest.prepareForCodegen();
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  std::vector<int> output(16, 0);
  SimpleIREvaluator evaluator(simplified, {b});
  evaluator(output);

  for (int pos = 0; pos < 16; ++pos) {
    ASSERT_EQ(output[pos], (pos + 8) * (pos + 8));
  }
}
1491 
1492 // Oversplit a loop that is simplified out after inlining.
TEST(LoopNest, ScheduleSplitInlineSimplify) {
  // a[i] = 4 * i - 2 * i, b[j] = a[j] - 1.
  Tensor a = Compute("a", {18}, [](const VarHandle& idx) {
    return ExprHandle(4) * idx - ExprHandle(2) * idx;
  });
  Tensor b = Compute("b", {2}, [&a](const VarHandle& idx) {
    return a.load(idx) - ExprHandle(1);
  });

  LoopNest nest({b}, {a, b});
  std::vector<ForPtr> a_loops =
      nest.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  LoopNest::splitWithMask(a_loops[0], 4);

  // The test expects inlining of the already-split producer to be refused.
  ASSERT_FALSE(nest.computeInline(a.buf()));
}
1505 
1506 // Inline a Compute with two consumers.
TEST(LoopNest, ScheduleInlineThreeMixedOnce) {
  // a[i] = i * i, b[j] = a[j + 8], c[k, l] = a[k] * b[l].
  // Only a is inlined; it is read by both b and c.
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
    return a.load(k) * b.load(l);
  });

  LoopNest nest({c}, {a, b, c});
  nest.computeInline(a.buf());
  nest.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(nest.root_stmt());
  std::vector<int> output(4 * 3, 0);
  SimpleIREvaluator eval(s, {c});
  eval(output);

  // output is indexed as k * 3 + col, matching c's {4, 3} extents.
  for (int k = 0; k < 4; ++k) {
    for (int col = 0; col < 3; ++col) {
      ASSERT_EQ(output[k * 3 + col], (k) * (k) * (col + 8) * (col + 8));
    }
  }
}
1531 
1532 // Inline Compute A into B, then inline B into C.
TEST(LoopNest, ScheduleInlineThreeMixedTwice) {
  // a[i] = i * i, b[j] = a[j + 8], c[k, l] = a[k] * b[l].
  // a is inlined first, then b.
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
    return a.load(k) * b.load(l);
  });

  LoopNest nest({c}, {a, b, c});
  nest.computeInline(a.buf());
  nest.computeInline(b.buf());
  nest.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(nest.root_stmt());
  std::vector<int> output(4 * 3, 0);
  SimpleIREvaluator eval(s, {c});
  eval(output);

  // output is indexed as k * 3 + col, matching c's {4, 3} extents.
  for (int k = 0; k < 4; ++k) {
    for (int col = 0; col < 3; ++col) {
      ASSERT_EQ(output[k * 3 + col], (k) * (k) * (col + 8) * (col + 8));
    }
  }
}
1558 
1559 // Inline a Compute that is both a producer and consumer.
TEST(LoopNest, ScheduleInlineThreeMixedInner) {
  // a[i] = i * i, b[j] = a[j + 8], c[k, l] = a[k] * b[l].
  // Only b is inlined; it both reads a and is read by c.
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
    return a.load(k) * b.load(l);
  });

  LoopNest nest({c}, {a, b, c});
  nest.computeInline(b.buf());
  nest.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(nest.root_stmt());
  std::vector<int> output(4 * 3, 0);
  SimpleIREvaluator eval(s, {c});
  eval(output);

  // output is indexed as k * 3 + col, matching c's {4, 3} extents.
  for (int k = 0; k < 4; ++k) {
    for (int col = 0; col < 3; ++col) {
      ASSERT_EQ(output[k * 3 + col], (k) * (k) * (col + 8) * (col + 8));
    }
  }
}
1584 
1585 // Split 3 Computes, then inline the first two into the last.
TEST(LoopNest, ScheduleInlineThreeMixedSplit) {
  // a[i] = i * i, b[j] = a[j + 8], c[k, l] = a[k] * b[l].
  Tensor a = Compute("a", {18}, [](const VarHandle& idx) { return idx * idx; });
  Tensor b = Compute("b", {6}, [&a](const VarHandle& idx) {
    return a.load(idx + ExprHandle(8));
  });
  Tensor c = Compute(
      "c", {4, 3}, [&a, &b](const VarHandle& row, const VarHandle& col) {
        return a.load(row) * b.load(col);
      });

  LoopNest nest({c}, {a, b, c});

  // Split the outermost loop of each tensor, in the order a, b, c.
  std::vector<ForPtr> a_loops =
      nest.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  LoopNest::splitWithMask(a_loops[0], 4);
  std::vector<ForPtr> b_loops =
      nest.getAllLoopNestsWritingToBuf(b.buf()).at(0);
  LoopNest::splitWithMask(b_loops[0], 3);
  std::vector<ForPtr> c_loops =
      nest.getAllLoopNestsWritingToBuf(c.buf()).at(0);
  LoopNest::splitWithMask(c_loops[0], 2);

  // The test expects inlining of the already-split producer to be refused.
  ASSERT_FALSE(nest.computeInline(a.buf()));
}
1604 
1605 // Check that inlining works for output tensors too
TEST(LoopNest, ScheduleInlineOutputTensors) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  // x[m, n, k] = m * n * k and y[m, n, k] = x[m, n, k] + m; both are outputs.
  Tensor x = Compute(
      "x",
      {M, N, K},
      [](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return mi * ni * ki;
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&x](const VarHandle& mi, const VarHandle& ni, const VarHandle& ki) {
        return x.load(mi, ni, ki) + mi;
      });

  LoopNest nest({x, y});
  nest.computeInline(x.buf());

  // Only the printed IR is inspected; nothing is evaluated in this test.
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());

  // x's own loop nest is still expected in the IR, while y computes x's
  // expression directly instead of loading from x.
  checkIR(simplified, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK:   for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK:     for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK:       x[i, i_1, i_2] = (i * i_1) * i_2;
# CHECK: for (int i_3 = 0; i_3 < 4; i_3++)
# CHECK:   for (int i_4 = 0; i_4 < 5; i_4++)
# CHECK:     for (int i_5 = 0; i_5 < 6; i_5++)
# CHECK:       y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR");
}
1642 
TEST(LoopNest, ScheduleInlineWithCompoundIndices) {
  // IR under test:
  //     for (int64_t i = 0; i < 100; i++) {
  //       A[i*2,i] = i * 500ll;
  //     }
  //     for (int64_t j = 0; j < 100; j++) {
  //       B[0ll,j] = A[0ll, j] + j * 100ll;
  //     }
  BufHandle a_buf("A", {20, 100}, kLong);
  BufHandle b_buf("B", {20, 100}, kLong);
  VarHandle i("i", kLong);
  VarHandle j("j", kLong);

  // Producer: the first store index (i * 2) is a compound expression.
  auto producer_store =
      Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast<int64_t>(500)));
  auto producer_loop = For::make(i, 0, 100, producer_store);

  // Consumer: reads A[0, j] and writes B[0, j].
  auto consumer_value = Add::make(
      Load::make(a_buf, {static_cast<int64_t>(0), j}),
      Mul::make(j, static_cast<int64_t>(100)));
  auto consumer_store =
      Store::make(b_buf, {static_cast<int64_t>(0), j}, consumer_value);
  auto consumer_loop = For::make(j, 0, 100, consumer_store);

  auto body = Block::make({producer_loop, consumer_loop});

  LoopNest nest(body, {b_buf.node()});
  // Inlining is expected to be refused because of the compound producer index.
  ASSERT_FALSE(nest.computeInline(a_buf.node()));

  // Both loops must still be present after the refused inlining.
  checkIR(nest.root_stmt(), R"IR(
    # CHECK: for (int64_t i = 0;
    # CHECK-NEXT:   A[
    # CHECK: for (int64_t j = 0;
    # CHECK-NEXT:   B[)IR");
}
1683 
TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) {
  // IR under test:
  //     for (int64_t i = 0; i < 100; i++) {
  //       A[0ll,i] = i * 500ll;
  //     }
  //     for (int64_t j = 0; j < 100; j++) {
  //       B[0ll,j] = A[(int64_t)0, j] + j * 100ll;
  //     }
  BufHandle a_buf("A", {20, 100}, kLong);
  BufHandle b_buf("B", {20, 100}, kLong);
  VarHandle i("i", kLong);
  VarHandle j("j", kLong);

  // Producer: indexed with an int64_t zero.
  auto producer_store = Store::make(
      a_buf,
      {static_cast<int64_t>(0), i},
      Mul::make(i, static_cast<int64_t>(500)));
  auto producer_loop = For::make(i, 0, 100, producer_store);

  // Consumer: the load's first index is a plain int literal 0, unlike the
  // int64_t zero used by the producer's store.
  auto consumer_value = Add::make(
      Load::make(a_buf, {0, j}), Mul::make(j, static_cast<int64_t>(100)));
  auto consumer_store =
      Store::make(b_buf, {static_cast<int64_t>(0), j}, consumer_value);
  auto consumer_loop = For::make(j, 0, 100, consumer_store);

  auto body = Block::make({producer_loop, consumer_loop});

  LoopNest nest(body, {b_buf.node()});
  ASSERT_TRUE(nest.computeInline(a_buf.node()));

  checkIR(nest.root_stmt(), R"IR(
    # CHECK: for (int64_t j = 0; j < 100; j++) {
    # CHECK:   B[0ll, j] = j * 500ll + j * 100ll;
    # CHECK: })IR");
}
1724 
TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) {
  // IR under test:
  //     for (int64_t i = 0; i < 100; i++) {
  //       A[(int64_t)0,i] = i * 500ll;
  //     }
  //     for (int64_t j = 0; j < 100; j++) {
  //       B[0ll,j] = A[0ll, j] + j * 100ll;
  //     }
  BufHandle a_buf("A", {20, 100}, kLong);
  BufHandle b_buf("B", {20, 100}, kLong);
  VarHandle i("i", kLong);
  VarHandle j("j", kLong);

  // Producer: the store's first index is a plain int literal 0, unlike the
  // int64_t zero used by the consumer's load.
  auto producer_store =
      Store::make(a_buf, {0, i}, Mul::make(i, static_cast<int64_t>(500)));
  auto producer_loop = For::make(i, 0, 100, producer_store);

  // Consumer: indexed with an int64_t zero.
  auto consumer_value = Add::make(
      Load::make(a_buf, {static_cast<int64_t>(0), j}),
      Mul::make(j, static_cast<int64_t>(100)));
  auto consumer_store =
      Store::make(b_buf, {static_cast<int64_t>(0), j}, consumer_value);
  auto consumer_loop = For::make(j, 0, 100, consumer_store);

  auto body = Block::make({producer_loop, consumer_loop});

  LoopNest nest(body, {b_buf.node()});
  ASSERT_TRUE(nest.computeInline(a_buf.node()));

  checkIR(nest.root_stmt(), R"IR(
    # CHECK: for (int64_t j = 0; j < 100; j++) {
    # CHECK:   B[0ll, j] = j * 500ll + j * 100ll;
    # CHECK: })IR");
}
1762 
TEST(LoopNest, ScheduleFuserStyle) {
  const int kVectorSize = 8;
  const int kVectorCount = 128;
  const int kTotalSize = kVectorSize * kVectorCount;

  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);

  // f = A + 11 and g = f + 1, each defined through the vector-of-axes form of
  // Compute.
  Tensor b =
      Compute("f", {kTotalSize}, [&](const std::vector<VarHandle>& indices) {
        return a_buf.load(indices[0]) + 11.0f;
      });
  Tensor c =
      Compute("g", {kTotalSize}, [&](const std::vector<VarHandle>& indices) {
        return b.load(indices[0]) + 1.0f;
      });

  LoopNest nest({b, c});
  nest.prepareForCodegen();
  StmtPtr root = nest.root_stmt();

  std::vector<float> a_data(kTotalSize, 7.0f);
  std::vector<float> b_data(kTotalSize, 0.0f);
  std::vector<float> c_data(kTotalSize, 0.0f);
  SimpleIREvaluator evaluator(root, {a_buf, b, c});
  evaluator(a_data, b_data, c_data);

  // 7 + 11 = 18 for f, and 18 + 1 = 19 for g, at every position.
  for (int pos = 0; pos < kTotalSize; ++pos) {
    ASSERT_EQ(b_data[pos], 18.0f);
    ASSERT_EQ(c_data[pos], 19.0f);
  }
}
1794 
// Chains three elementwise additions (e = A + B, f = e + C, g = f + D),
// inlines the two intermediate tensors into g and checks the evaluated result.
TEST(LoopNest, ScheduleFuserThreeArg) {
  const int kLanes = 8;
  const int kNumVectors = 128;
  const int kNumElements = kLanes * kNumVectors;

  BufHandle buf_a("A", {ExprHandle(kNumElements)}, kFloat);
  BufHandle buf_b("B", {ExprHandle(kNumElements)}, kFloat);
  BufHandle buf_c("C", {ExprHandle(kNumElements)}, kFloat);
  BufHandle buf_d("D", {ExprHandle(kNumElements)}, kFloat);

  Tensor sum_ab = Compute("e", {kNumElements}, [&](const VarHandle& idx) {
    return buf_a.load(idx) + buf_b.load(idx);
  });
  Tensor sum_abc = Compute("f", {kNumElements}, [&](const VarHandle& idx) {
    return sum_ab.load(idx) + buf_c.load(idx);
  });
  Tensor sum_abcd = Compute("g", {kNumElements}, [&](const VarHandle& idx) {
    return sum_abc.load(idx) + buf_d.load(idx);
  });

  // Only g is an output; e and f are inlined away before codegen.
  LoopNest nest({sum_abcd}, {sum_ab, sum_abc, sum_abcd});
  nest.computeInline(nest.getLoopBodyFor(sum_ab));
  nest.computeInline(nest.getLoopBodyFor(sum_abc));
  nest.prepareForCodegen();
  StmtPtr lowered = nest.root_stmt();

  std::vector<float> a_values(kNumElements, 1.0f);
  std::vector<float> b_values(kNumElements, 2.0f);
  std::vector<float> c_values(kNumElements, 3.0f);
  std::vector<float> d_values(kNumElements, 4.0f);
  std::vector<float> g_values(kNumElements, 0.0f);
  SimpleIREvaluator evaluator(
      lowered, {buf_a, buf_b, buf_c, buf_d, sum_abcd});
  evaluator(a_values, b_values, c_values, d_values, g_values);

  // 1 + 2 + 3 + 4 == 10 for every element.
  for (int pos = 0; pos < kNumElements; ++pos) {
    ASSERT_EQ(g_values[pos], 10.0f);
  }
}
1832 
// Evaluates a 2-D elementwise add whose extents are symbolic variables that
// are only bound to concrete sizes at call time.
TEST(LoopNest, ScheduleDynamicShape2D) {
  auto runCase = [](int32_t rows, int32_t cols) {
    VarHandle m("m", kInt);
    VarHandle n("n", kInt);
    BufHandle lhs("a", {m, n}, kFloat);
    BufHandle rhs("b", {m, n}, kFloat);
    Tensor sum =
        Compute("c", {m, n}, [&](const VarHandle& r, const VarHandle& col) {
          return lhs.load(r, col) + rhs.load(r, col);
        });
    LoopNest nest({sum});
    StmtPtr body = nest.root_stmt();
    // The symbolic extents are passed as trailing evaluator arguments.
    SimpleIREvaluator evaluator(body, {lhs, rhs, sum, m, n});

    std::vector<float> lhs_values(rows * cols, 1.0f);
    std::vector<float> rhs_values(rows * cols, 2.0f);
    std::vector<float> sum_values(rows * cols, 0.0f);
    evaluator.call({lhs_values, rhs_values, sum_values, rows, cols});

    ExpectAllNear(sum_values, std::vector<float>(rows * cols, 3.0f), 1e-7);
  };

  runCase(1, 8);
  runCase(16, 32);
  runCase(37, 11);
}
1856 
TEST(LoopNest, LoopNestComputeAt_1) {
  // compute_at on a simple producer/consumer pair:
  //
  //   for (int i_a = 0; i_a < N; i_a++) {
  //     A[i_a] = i_a * i_a
  //   }
  //   for (int i_b = 0; i_b < N; i_b++) {
  //     B[i_b] = A[i_b]
  //   }
  //
  // Expected outcome: the i_b loop gets a temporary buffer that holds the
  // value of A needed by the current iteration, and B reads from that
  // temporary. A must not be referenced inside the loop any more, and its
  // computation must not be inlined into B - it goes through the temp.
  VarHandle N("N", kInt);
  Tensor producer =
      Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; });
  Tensor consumer = Compute(
      "B", {N}, [&](const VarHandle& i_b) { return producer.load(i_b); });

  LoopNest nest({consumer}, {producer, consumer});
  std::vector<ForPtr> consumer_loops =
      nest.getAllLoopNestsWritingToBuf(consumer.buf()).at(0);
  LoopNest::computeAt(nest.getLoopBodyFor(producer), consumer_loops[0]);
  nest.prepareForCodegen();

  SimpleIREvaluator evaluator(nest.root_stmt(), {consumer, N});
  StmtPtr lowered = evaluator.stmt();

  checkIR(lowered, R"IR(
# CHECK: Allocate(temp); // dtype=int, dims=[1]
# CHECK: for (int i = 0; i < N; i++)
# CHECK:   temp[
# CHECK-NOT: A[
# CHECK:   B[i_1] = temp[0]
# CHECK:   Free(temp))IR");

  // The transformed loop must still compute B[i] = i * i.
  const int kN = 100;
  std::vector<int> actual(kN, 0);
  evaluator.call({actual, kN});

  std::vector<int> expected(kN, 0);
  for (int v = 0; v < kN; ++v) {
    expected[v] = v * v;
  }
  assertAllEqual(actual, expected);
}
1901 
TEST(LoopNest,LoopNestComputeAt_2)1902 TEST(LoopNest, LoopNestComputeAt_2) {
1903   // Verify that compute_at works on the following example:
1904   //
1905   // for (int py = 0; py < H+1; py++) {
1906   //   for (int px = 0; px < W+1; px++) {
1907   //     p[py, px] = py*px
1908   //   }
1909   // }
1910   // for (int cy = 0; cy < H; cy++) {
1911   //   for (int cx = 0; cx < W; cx++) {
1912   //     c[py, px] = p[cy,cx]   + p[cy+1,cx] +
1913   //                 p[cy,cx+1] + p[cy+1,cx+1]
1914   //   }
1915   // }
1916 
1917   const int kW = 16, kH = 16;
1918   VarHandle W("W", kInt);
1919   VarHandle H("H", kInt);
1920   Tensor p = Compute(
1921       "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) {
1922         return px * py;
1923       });
1924   Tensor c =
1925       Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
1926         return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) +
1927             p.load(y + 1, x + 1);
1928       });
1929 
1930   std::vector<int> c_ref(kW * kH, 0);
1931   for (int y = 0; y < kH; y++) {
1932     for (int x = 0; x < kW; x++) {
1933       c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
1934     }
1935   }
1936   LoopNest orig_loopnest({c}, {p, c});
1937 
1938   {
1939     // First let's try to compute P at axis cy (the outer loop)
1940     LoopNest l(orig_loopnest);
1941     std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
1942     LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]);
1943     l.prepareForCodegen();
1944     SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
1945     StmtPtr s = cg.stmt();
1946 
1947     // Check the IR we produced
1948     checkIR(s, R"IR(
1949 # CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
1950 # CHECK: for (int i_2 = 0; i_2 < H; i_2++)
1951 # CHECK:   for
1952 # CHECK:     for
1953 # CHECK:   for (int i_3 = 0; i_3 < W; i_3++)
1954 # CHECK-NOT: prod[
1955 # CHECK:     cons[
1956 # CHECK: Free(temp))IR");
1957 
1958     // Now check that the loop still produces the correct result.
1959     std::vector<int> c_data(kW * kH, 0);
1960     cg.call({c_data, kW, kH});
1961 
1962     assertAllEqual(c_data, c_ref);
1963   }
1964   {
1965     // Now let's try to compute P at axis cx (the inner loop)
1966     LoopNest l(orig_loopnest);
1967     std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
1968     LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]);
1969     l.prepareForCodegen();
1970     SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
1971     StmtPtr s = cg.stmt();
1972 
1973     // Check the IR we produced
1974     checkIR(s, R"IR(
1975 # CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
1976 # CHECK: for (int i_2 = 0; i_2 < H; i_2++)
1977 # CHECK:   for (int i_3 = 0; i_3 < W; i_3++)
1978 # CHECK:     for
1979 # CHECK:       for
1980 # CHECK-NOT: prod[
1981 # CHECK:     cons[
1982 # CHECK: Free(temp))IR");
1983 
1984     // Now check that the loop still produces the correct result.
1985     std::vector<int> c_data(kW * kH, 0);
1986     cg.call({c_data, kW, kH});
1987 
1988     assertAllEqual(c_data, c_ref);
1989   }
1990 }
1991 
TEST(LoopNest,LoopNestComputeAt_3)1992 TEST(LoopNest, LoopNestComputeAt_3) {
1993   // Verify that compute_at works on the following example:
1994   //
1995   // A(x,y) = x*y
1996   // B(x,y) = A(x, y)
1997   // C(x,y) = B(x+1, y)
1998   // D(x,y) = A(x, y+1) + C(x, y)
1999   //
2000   // i.e. when 'A' comes to 'D' directly and indirectly through 'C'.
2001 
2002   const int kW = 16, kH = 16;
2003   VarHandle W("W", kInt);
2004   VarHandle H("H", kInt);
2005   Tensor A = Compute(
2006       "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) {
2007         return ax * ay;
2008       });
2009   Tensor B = Compute(
2010       "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) {
2011         return A.load(by, bx);
2012       });
2013   Tensor C =
2014       Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) {
2015         return B.load(cy, cx + 1);
2016       });
2017   Tensor D =
2018       Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) {
2019         return A.load(dy + 1, dx) + C.load(dy, dx);
2020       });
2021 
2022   std::vector<int> c_ref(kW * kH, 0);
2023   for (int y = 0; y < kH; y++) {
2024     for (int x = 0; x < kW; x++) {
2025       c_ref[y * kW + x] = (y + 1) * x + y * (x + 1);
2026     }
2027   }
2028 
2029   LoopNest orig_loopnest({D}, {A, B, C, D});
2030   {
2031     // First let's try to compute A at axis dy (the outer loop)
2032     LoopNest l(orig_loopnest);
2033     std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0);
2034     LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
2035     l.prepareForCodegen();
2036     SimpleIREvaluator cg(l.root_stmt(), {D, W, H});
2037     StmtPtr s = cg.stmt();
2038 
2039     // Check the IR we produced
2040     checkIR(s, R"IR(
2041 # CHECK: Allocate(temp); // dtype=int, dims=[1, W]
2042 # CHECK: for (int i = 0; i < H + 1; i++)
2043 # CHECK:   for (int i_1 = 0; i_1 < W + 1; i_1++)
2044 # CHECK:     A[
2045 # CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++)
2046 # CHECK:   for (int i_3 = 0; i_3 < W + 1; i_3++)
2047 # CHECK:     B[
2048 # CHECK: for (int i_4 = 0; i_4 < H; i_4++)
2049 # CHECK:   for (int i_5 = 0; i_5 < W; i_5++)
2050 # CHECK:     C[
2051 # CHECK: for (int i_6 = 0; i_6 < H; i_6++)
2052 # CHECK:   for (int i_7 = 0; i_7 < W; i_7++)
2053 # CHECK-NOT: A[)IR");
2054 
2055     // Now check that the loop still produces the correct result.
2056     std::vector<int> c_data(kW * kH, 0);
2057     cg.call({c_data, kW, kH});
2058 
2059     assertAllEqual(c_data, c_ref);
2060   }
2061   {
2062     // Now let's try to compute A at axis dx (the inner loop)
2063     LoopNest l(orig_loopnest);
2064     std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0);
2065     LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]);
2066     l.prepareForCodegen();
2067     SimpleIREvaluator cg(l.root_stmt(), {D, W, H});
2068     StmtPtr s = cg.stmt();
2069 
2070     // Check the IR we produced
2071     checkIR(s, R"IR(
2072 # CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
2073 # CHECK: for (int i = 0; i < H + 1; i++)
2074 # CHECK:   for (int i_1 = 0; i_1 < W + 1; i_1++)
2075 # CHECK:     A[
2076 # CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++)
2077 # CHECK:   for (int i_3 = 0; i_3 < W + 1; i_3++)
2078 # CHECK:     B[
2079 # CHECK: for (int i_4 = 0; i_4 < H; i_4++)
2080 # CHECK:   for (int i_5 = 0; i_5 < W; i_5++)
2081 # CHECK:     C[
2082 # CHECK: for (int i_6 = 0; i_6 < H; i_6++)
2083 # CHECK:   for (int i_7 = 0; i_7 < W; i_7++)
2084 # CHECK-NOT: A[)IR");
2085 
2086     // Now check that the loop still produces the correct result.
2087     std::vector<int> c_data(kW * kH, 0);
2088     cg.call({c_data, kW, kH});
2089 
2090     assertAllEqual(c_data, c_ref);
2091   }
2092 }
2093 
// Shorthand for the index-variable parameter type of the Compute/Reduce
// lambdas written in the tests that follow.
using Axis = const VarHandle&;
2095 
// compute_at where the consumer is a 2x2 sum-reduction over the producer.
// The producer is computed first at the outer and then at the inner consumer
// loop; in both cases the generated IR and the numeric result are checked.
TEST(LoopNest, Reduce2dComputeAt) {
  const int kW = 16, kH = 16;
  VarHandle W("W", kInt);
  VarHandle H("H", kInt);

  Tensor producer = Compute(
      "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; });
  Tensor consumer = Reduce(
      "cons",
      {H, W},
      Sum(),
      [&](Axis y, Axis x, Axis ry, Axis rx) {
        return producer.load(y + ry, x + rx);
      },
      {2, 2});

  // Reference result: sum of the producer over the 2x2 window at (y, x).
  std::vector<int> expected(kW * kH, 0);
  for (int y = 0; y < kH; y++) {
    for (int x = 0; x < kW; x++) {
      int window_sum = 0;
      for (int dy = 0; dy < 2; dy++) {
        for (int dx = 0; dx < 2; dx++) {
          window_sum += (y + dy) * (x + dx);
        }
      }
      expected[y * kW + x] = window_sum;
    }
  }

  LoopNest orig_loopnest({consumer}, {producer, consumer});
  // Baseline IR before any transformation.
  checkIR(orig_loopnest.root_stmt(), R"IR(
# CHECK: for (int i = 0; i < H + 1; i++) {
# CHECK:   for (int i_1 = 0; i_1 < W + 1; i_1++) {
# CHECK:     prod[i, i_1] = i_1 * i;
# CHECK:   }
# CHECK: }
# CHECK: for (int i_2 = 0; i_2 < H; i_2++) {
# CHECK:   for (int i_3 = 0; i_3 < W; i_3++) {
# CHECK:     cons[i_2, i_3] = int(0);
# CHECK:     for (int i_4 = 0; i_4 < 2; i_4++) {
# CHECK:       for (int i_5 = 0; i_5 < 2; i_5++) {
# CHECK:         cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5});
# CHECK:       }
# CHECK:     }
# CHECK:   }
# CHECK: }
)IR");

  {
    // Producer computed at the outer consumer loop (cy).
    LoopNest nest(orig_loopnest);
    auto consumer_loops =
        nest.getAllLoopNestsWritingToBuf(consumer.buf()).at(0);
    LoopNest::computeAt(nest.getLoopBodyFor(producer), consumer_loops[0]);
    // FIXME: simplify() is deliberately not called here because it breaks
    // the IR:
    // MALFORMED INPUT: could not find base node in Load - temp[...]
    // nest.simplify();
    nest.eliminateDeadStores();
    nest.prepareForCodegen();
    SimpleIREvaluator evaluator(nest.root_stmt(), {consumer, W, H});
    checkIR(evaluator.stmt(), R"IR(
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
# CHECK: for (int i = 0; i < H; i++) {
# CHECK:   for (int idx0 = 0; idx0 < 2; idx0++) {
# CHECK:     for (int idx1 = 0; idx1 < W + 1; idx1++) {
# CHECK:       temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0);
# CHECK:     }
# CHECK:   }
# CHECK:   for (int i_1 = 0; i_1 < W; i_1++) {
# CHECK:     cons[(0 + i * (1 * W)) + i_1 * 1] = int(0);
# CHECK:     for (int i_2 = 0; i_2 < 2; i_2++) {
# CHECK:       for (int i_3 = 0; i_3 < 2; i_3++) {
# CHECK:         cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]);
# CHECK:       }
# CHECK:     }
# CHECK:   }
# CHECK: }
# CHECK: Free(temp);
)IR");

    // The transformed loop nest must still produce the reference result.
    std::vector<int> actual(kW * kH, 0);
    evaluator.call({actual, kW, kH});
    assertAllEqual(actual, expected);
  }
  {
    // Producer computed at the inner consumer loop (cx).
    LoopNest nest(orig_loopnest);
    std::vector<ForPtr> consumer_loops =
        nest.getAllLoopNestsWritingToBuf(consumer.buf()).at(0);
    LoopNest::computeAt(nest.getLoopBodyFor(producer), consumer_loops[1]);
    nest.simplify();
    nest.eliminateDeadStores();
    nest.prepareForCodegen();
    SimpleIREvaluator evaluator(nest.root_stmt(), {consumer, W, H});
    checkIR(evaluator.stmt(), R"IR(
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
# CHECK: for (int i = 0; i < H; i++) {
# CHECK:   for (int i_1 = 0; i_1 < W; i_1++) {
# CHECK:     for (int idx0 = 0; idx0 < 2; idx0++) {
# CHECK:       for (int idx1 = 0; idx1 < 2; idx1++) {
# CHECK:         temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1);
# CHECK:       }
# CHECK:     }
# CHECK:     cons[(0 + i * (1 * W)) + i_1 * 1] = 0;
# CHECK:     for (int i_2 = 0; i_2 < 2; i_2++) {
# CHECK:       for (int i_3 = 0; i_3 < 2; i_3++) {
# CHECK:         cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]);
# CHECK:       }
# CHECK:     }
# CHECK:   }
# CHECK: }
# CHECK: Free(temp);
)IR");

    // The transformed loop nest must still produce the reference result.
    std::vector<int> actual(kW * kH, 0);
    evaluator.call({actual, kW, kH});
    assertAllEqual(actual, expected);
  }
}
2206 
// Disabled test: builds a padded copy A of `input`, reduces it with a width-R
// sum into B, computes A at the outer loop of B and compares the evaluated
// result against at::conv1d.
TEST(LoopNest, DISABLED_Conv1d_NH) {
  // Lots of stuff is broken here.  The computeAt swaps the axes for some odd
  // reason.  Even without that, the index flattener fails due to "dimensions
  // mismatch in flatten index".

  int N = 4;
  int H = 256;
  int R = 3;
  int Pad = 1;
  // NOTE(review): `input` is declared with a single dimension {H}, yet it is
  // loaded below with two indices (n, h - Pad) - confirm the intended shape.
  BufHandle IP("input", {H}, kFloat);

  // A is the input padded by `Pad` on both ends of the h axis: positions with
  // h < Pad or h >= H + Pad yield 0, all others read input at h - Pad.
  Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) {
    auto cond = CompareSelect::make(h, Pad, 1, 0, kLT);
    cond = CompareSelect::make(h, H + Pad, 1, cond, kGE);
    return ifThenElse(cond, 0.f, IP.load(n, h - Pad));
  });
  // B sums R consecutive elements of A along h.
  Tensor B = Reduce(
      "B",
      {N, H},
      Sum(),
      [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); },
      {R});
  LoopNest l({B});
  // Expected IR before the transformation.
  checkIR(l.root_stmt(), R"IR(
# CHECK: for (int np = 0; np < 4; np++) {
# CHECK:   for (int hp = 0; hp < 258; hp++) {
# CHECK:     A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]);
# CHECK:   }
# CHECK: }
# CHECK: for (int n = 0; n < 4; n++) {
# CHECK:   for (int h = 0; h < 256; h++) {
# CHECK:     B[n, h] = float(0);
# CHECK:     for (int r = 0; r < 3; r++) {
# CHECK:       B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r});
# CHECK:     }
# CHECK:   }
# CHECK: }
)IR");
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0);
  LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
  // FIXME: The current IR is totally broken.  The body of the inlined loop is:

  // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0),
  // 0.f, input[idx1 + 0, (idx0 + n) - 1]);

  // Which seems to mix up the axes.  The CHECK below is my best guess at what
  // the input "should" look like

  // NOTE(review): the `temp[idx0, idx1] = ...` line inside the pattern below
  // carries no "# CHECK:" prefix, unlike every other line of the pattern -
  // confirm whether that omission is intentional.
  checkIR(l.root_stmt(), R"IR(
# CHECK: for (int n = 0; n < 4; n++) {
# CHECK:   for (int idx0 = 0; idx0 < 1; idx0++) {
# CHECK:     for (int idx1 = 0; idx1 < 258; idx1++) {
        temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]);
# CHECK:     }
# CHECK:   }
# CHECK:   for (int h = 0; h < 256; h++) {
# CHECK:     B[n, h] = float(0);
# CHECK:     for (int r = 0; r < 3; r++) {
# CHECK:       B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r});
# CHECK:     }
# CHECK:   }
# CHECK: }
)IR");

  l.simplify();
  l.prepareForCodegen();
  StmtPtr s = l.root_stmt();

  SimpleIREvaluator cg(s, {IP, B});
  // auto At = at::ones({N, H}, at::kFloat);
  auto At = at::arange(N * H, at::kFloat).reshape({N, H});
  // NOTE(review): the reference convolution uses padding=3 while the IR above
  // pads by Pad = 1 - confirm which value is intended.
  auto Rt = at::conv1d(
      At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3);
  auto Bt = at::empty_like(Rt);
  cg.call({At.data_ptr<float>(), Bt.data_ptr<float>()});
  ASSERT_TRUE(at::allclose(Rt, Bt));
}
2284 
2285 class LoopOrderHelper : public IRVisitor {
2286   std::stringstream ordering;
2287 
2288  public:
getOrder(StmtPtr s)2289   std::string getOrder(StmtPtr s) {
2290     ordering.str("");
2291     s->accept(this);
2292     return ordering.str();
2293   }
2294 
visit(const ForPtr & v)2295   void visit(const ForPtr& v) final {
2296     ordering << v->var()->name_hint() << ",";
2297     IRVisitor::visit(v);
2298   }
2299 };
2300 
// Swaps the two loops of a 2-D computation and checks that the loop order
// changes while the evaluated output stays the same; then swaps them back and
// checks that the printed statement equals the one printed for the original.
TEST(LoopNest, LoopNestReorderAxis1) {
  Tensor tensor =
      Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) {
        return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
      });
  LoopNest l({tensor});
  StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));

  // Output of the statement before reordering.
  std::vector<int> stmt1_output(6, 0);
  SimpleIREvaluator cg(stmt1, {tensor});
  cg.call({stmt1_output});

  auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[0], loops[1]);
  StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));

  // Compares the two statement handles, not their printed form.
  ASSERT_NE(stmt1, stmt2);
  LoopOrderHelper loopOrderHelper;
  std::string order1 = loopOrderHelper.getOrder(stmt1);
  std::string order2 = loopOrderHelper.getOrder(stmt2);

  // NOTE(review): stmt1 reads as "j,i," only after stmt2 has been sanitized;
  // presumably the clones share their loop variables, so the second
  // sanitizeNames call renames them for both statements - confirm.
  ASSERT_EQ(order1, "j,i,");
  ASSERT_EQ(order2, "i,j,");

  // Output of the reordered statement. This must be evaluated through cg2
  // (built from stmt2); calling cg again would just re-run stmt1 and make the
  // comparison below trivially true.
  std::vector<int> stmt2_output(6, 0);
  SimpleIREvaluator cg2(stmt2, {tensor});
  cg2.call({stmt2_output});

  for (int i = 0; i < 6; ++i) {
    ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
  }

  // Reorder them back.
  loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[0], loops[1]);
  StmtPtr stmt3 = l.root_stmt();

  std::string order3 = loopOrderHelper.getOrder(stmt3);
  ASSERT_EQ(order3, order1);

  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt3;

  // Should be identical to the unreordered statement.
  ASSERT_EQ(oss1.str(), oss2.str());
}
2348 
// Reorders adjacent pairs of loops in a 3-deep nest, one pair at a time, and
// checks after each step both the resulting loop order and that the evaluated
// output is unchanged.
TEST(LoopNest, LoopNestReorderPartialAxes) {
  Tensor tensor = Compute(
      "f",
      {2, 3, 4},
      [](const VarHandle& x, const VarHandle& y, const VarHandle& z) {
        return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y +
            cast<float>(z) * z;
      });
  LoopNest nest({tensor});

  LoopOrderHelper order_helper;
  StmtPtr baseline_stmt = LoopNest::sanitizeNames(Stmt::clone(nest.root_stmt()));
  ASSERT_EQ(order_helper.getOrder(baseline_stmt), "i,j,k,");

  // Reference output from the untransformed statement.
  std::vector<int> baseline_out(24, 0);
  SimpleIREvaluator baseline_eval(baseline_stmt, {tensor});
  baseline_eval.call({baseline_out});

  // Step 1: swap the two outermost loops.
  auto loops = nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[0], loops[1]);
  ASSERT_EQ(order_helper.getOrder(nest.root_stmt()), "j,i,k,");

  StmtPtr swapped_outer_stmt = Stmt::clone(nest.root_stmt());

  std::vector<int> swapped_outer_out(24, 0);
  SimpleIREvaluator swapped_outer_eval(swapped_outer_stmt, {tensor});
  swapped_outer_eval.call({swapped_outer_out});

  for (int pos = 0; pos < 24; ++pos) {
    ASSERT_EQ(baseline_out[pos], swapped_outer_out[pos]);
  }

  // Step 2: swap the two innermost loops of the already reordered nest.
  loops = nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[1], loops[2]);
  ASSERT_EQ(order_helper.getOrder(nest.root_stmt()), "j,k,i,");

  StmtPtr swapped_inner_stmt = Stmt::clone(nest.root_stmt());

  std::vector<int> swapped_inner_out(24, 0);
  SimpleIREvaluator swapped_inner_eval(swapped_inner_stmt, {tensor});
  swapped_inner_eval.call({swapped_inner_out});

  for (int pos = 0; pos < 24; ++pos) {
    ASSERT_EQ(baseline_out[pos], swapped_inner_out[pos]);
  }
}
2395 
// Swaps the two middle loops of a 4-deep nest and checks the resulting loop
// order and that the evaluated output is unchanged.
TEST(LoopNest, LoopNestReorderInternalAxis) {
  Tensor tensor = Compute(
      "f",
      {1, 2, 3, 4},
      [](const VarHandle& w,
         const VarHandle& x,
         const VarHandle& y,
         const VarHandle& z) {
        return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y +
            cast<float>(z) * z;
      });
  LoopNest nest({tensor});

  LoopOrderHelper order_helper;
  StmtPtr baseline_stmt = LoopNest::sanitizeNames(Stmt::clone(nest.root_stmt()));
  ASSERT_EQ(order_helper.getOrder(baseline_stmt), "i,j,k,l,");

  // Reference output from the untransformed statement.
  std::vector<int> baseline_out(24, 0);
  SimpleIREvaluator baseline_eval(baseline_stmt, {tensor});
  baseline_eval.call({baseline_out});

  auto loops = nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[2], loops[1]);
  ASSERT_EQ(order_helper.getOrder(nest.root_stmt()), "i,k,j,l,");

  StmtPtr reordered_stmt = nest.root_stmt();

  std::vector<int> reordered_out(24, 0);
  SimpleIREvaluator reordered_eval(reordered_stmt, {tensor});
  reordered_eval.call({reordered_out});

  for (int pos = 0; pos < 24; ++pos) {
    ASSERT_EQ(baseline_out[pos], reordered_out[pos]);
  }
}
2431 
// Swaps the outermost and innermost loops of a 4-deep nest and checks the
// resulting loop order and that the evaluated output is unchanged.
TEST(LoopNest, LoopNestReorderEnclosingAxis) {
  Tensor tensor = Compute(
      "f",
      {1, 2, 3, 4},
      [](const VarHandle& w,
         const VarHandle& x,
         const VarHandle& y,
         const VarHandle& z) {
        return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y +
            cast<float>(z) * z;
      });
  LoopNest nest({tensor});

  LoopOrderHelper order_helper;
  StmtPtr baseline_stmt = LoopNest::sanitizeNames(Stmt::clone(nest.root_stmt()));

  // Reference output from the untransformed statement.
  std::vector<int> baseline_out(24, 0);
  SimpleIREvaluator baseline_eval(baseline_stmt, {tensor});
  baseline_eval.call({baseline_out});

  auto loops = nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[0], loops[3]);
  ASSERT_EQ(order_helper.getOrder(nest.root_stmt()), "l,j,k,i,");

  StmtPtr reordered_stmt = nest.root_stmt();

  std::vector<int> reordered_out(24, 0);
  SimpleIREvaluator reordered_eval(reordered_stmt, {tensor});
  reordered_eval.call({reordered_out});

  for (int pos = 0; pos < 24; ++pos) {
    ASSERT_EQ(baseline_out[pos], reordered_out[pos]);
  }
}
2466 
// Reordering a loop with itself must leave the printed statement unchanged.
TEST(LoopNest, LoopNestReorderSameAxis) {
  Tensor tensor =
      Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) {
        return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
      });
  LoopNest nest({tensor});
  StmtPtr before = Stmt::clone(nest.root_stmt());

  auto loops = nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[1], loops[1]);
  StmtPtr after = Stmt::clone(nest.root_stmt());

  std::ostringstream before_text;
  std::ostringstream after_text;
  before_text << *before;
  after_text << *after;
  ASSERT_EQ(before_text.str(), after_text.str());
}
2484 
TEST(LoopNest,LoopNestReorderExtraStatements)2485 TEST(LoopNest, LoopNestReorderExtraStatements) {
2486   /* We're going for a structure like this:
2487    * for i in ...
2488    *   Stmt 1
2489    *   for j in ...
2490    *     Stmt 2
2491    *     for k in ...
2492    *       Stmt 3
2493    *     Stmt 4
2494    */
2495 
2496   Tensor tensor = Compute(
2497       "f",
2498       {2, 3, 4},
2499       [](const VarHandle& x, const VarHandle& y, const VarHandle& z) {
2500         return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y +
2501             cast<float>(z) * z;
2502       });
2503   LoopNest l({tensor});
2504 
2505   BufHandle extra("res", {6, 3}, kFloat);
2506 
2507   auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2508 
2509   VarHandle i = VarHandle(loops[0]->var());
2510 
2511   StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f);
2512   StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f);
2513   // stmt 3 is the Function body.
2514   StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f);
2515 
2516   loops[0]->body()->prepend_stmt(store_1);
2517   loops[1]->body()->prepend_stmt(store_2);
2518   loops[1]->body()->append_stmt(store_3);
2519   StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2520 
2521   std::vector<int> extra1(6, 0);
2522   std::vector<int> res1(24, 0);
2523   SimpleIREvaluator cg(stmt1, {tensor, extra});
2524   cg.call({res1, extra1});
2525 
2526   /* Then we reorder loop y and z, we want it to look like:
2527    *
2528    * for i in ...
2529    *   Stmt 1
2530    *   for j in ...
2531    *     Stmt 2
2532    *   for j_1 in ...
2533    *    for k in ...
2534    *       Stmt 3
2535    *   for j_2 in ...
2536    *     Stmt 4
2537    *
2538    * We need extra loops because we don't have dependency info about stmt 3
2539    * and 4.
2540    *
2541    */
2542 
2543   LoopNest::reorderAxis(loops[1], loops[2]);
2544   StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2545 
2546   // Check the IR we produced
2547   checkIR(stmt2, R"IR(
2548 # CHECK: for
2549 # CHECK:   res[i, 0] = 1
2550 # CHECK:   for
2551 # CHECK:     res[i, 1] = 2
2552 # CHECK:   for
2553 # CHECK:     for
2554 # CHECK:       f[
2555 # CHECK:   for
2556 # CHECK:     res[i, 2] = 4
2557 )IR");
2558 
2559   std::vector<int> extra2(6, 0);
2560   std::vector<int> res2(24, 0);
2561   SimpleIREvaluator cg2(stmt2, {tensor, extra});
2562   cg2.call({res2, extra2});
2563 
2564   for (int i = 0; i < 24; ++i) {
2565     ASSERT_EQ(res1[i], res2[i]);
2566   }
2567   for (int i = 0; i < 6; ++i) {
2568     ASSERT_EQ(extra1[i], extra2[i]);
2569   }
2570 
2571   /* Now reorder x and the y above stmt 3:
2572    *
2573    *
2574    * for x in ...
2575    *   Stmt 1
2576    *   for y in ...
2577    *     Stmt 2
2578    *
2579    * for y in ...
2580    *   for z in ...
2581    *    for x in ...
2582    *       Stmt 3
2583    *
2584    * for x in ...
2585    *   for y in ...
2586    *     Stmt 4
2587    *
2588    *
2589    */
2590   loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2591   LoopNest::reorderAxis(loops[0], loops[2]);
2592   StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2593 
2594   // Check the IR we produced
2595   checkIR(stmt3, R"IR(
2596 # CHECK: for
2597 # CHECK:   res[i, 0] = 1
2598 # CHECK:   for
2599 # CHECK:     res[i, 1] = 2
2600 # CHECK: for
2601 # CHECK:   for
2602 # CHECK:     for
2603 # CHECK:       f[
2604 # CHECK: for
2605 # CHECK:   for
2606 # CHECK:     res[i_2, 2] = 4
2607 )IR");
2608 
2609   std::vector<int> extra3(6, 0);
2610   std::vector<int> res3(24, 0);
2611   SimpleIREvaluator cg3(stmt3, {tensor, extra});
2612   cg3.call({res3, extra3});
2613 
2614   for (int i = 0; i < 24; ++i) {
2615     ASSERT_EQ(res1[i], res3[i]);
2616   }
2617   for (int i = 0; i < 6; ++i) {
2618     ASSERT_EQ(extra1[i], extra3[i]);
2619   }
2620 }
2621 
LoopNestReorderTestHelper(bool prepend,bool append,int index1,int index2)2622 void LoopNestReorderTestHelper(
2623     bool prepend,
2624     bool append,
2625     int index1,
2626     int index2) {
2627   Tensor c = Compute(
2628       "5d", {2, 3, 2, 3, 2}, [](const std::vector<VarHandle>&) { return -1; });
2629   LoopNest l({c});
2630 
2631   BufHandle extra("extra", {5}, kInt);
2632 
2633   auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
2634   int j = 0;
2635   for (auto l : loops) {
2636     // Add an increment at each layer of the loop which counts the number of
2637     // times the loop executes.
2638     LoadPtr load =
2639         alloc<Load>(extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}));
2640     AddPtr add = alloc<Add>(load, alloc<IntImm>(1));
2641     StmtPtr store = alloc<Store>(
2642         extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}), add);
2643     if (prepend) {
2644       l->body()->prepend_stmt(store);
2645     }
2646     if (append) {
2647       l->body()->append_stmt(Stmt::clone(store));
2648     }
2649 
2650     j++;
2651   }
2652 
2653   StmtPtr stmt1 = Stmt::clone(l.root_stmt());
2654 
2655   std::vector<int> extra1(5, 0);
2656   std::vector<int> res1(2 * 3 * 2 * 3 * 2, 0);
2657   SimpleIREvaluator cg(stmt1, {c, extra});
2658   cg.call({res1, extra1});
2659 
2660   std::vector<int> loopExtents = {2, 3, 2, 3, 2};
2661 
2662   int expected_loops = 0;
2663   if (prepend) {
2664     expected_loops++;
2665   }
2666   if (append) {
2667     expected_loops++;
2668   }
2669   for (int i = 0; i < 5; ++i) {
2670     expected_loops *= loopExtents[i];
2671     ASSERT_EQ(extra1[i], expected_loops);
2672   }
2673 
2674   loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
2675   LoopNest::reorderAxis(loops[index1], loops[index2]);
2676   StmtPtr stmt2 = Stmt::clone(l.root_stmt());
2677 
2678   std::ostringstream oss, oss2;
2679   oss << *stmt1;
2680   oss2 << *stmt2;
2681   ASSERT_NE(oss.str(), oss2.str());
2682 
2683   std::vector<int> extra2(5, 0);
2684   std::vector<int> res2(2 * 3 * 2 * 3 * 2, 0);
2685   SimpleIREvaluator cg2(stmt2, {c, extra});
2686   cg2.call({res2, extra2});
2687 
2688   expected_loops = 0;
2689   if (prepend) {
2690     expected_loops++;
2691   }
2692   if (append) {
2693     expected_loops++;
2694   }
2695 
2696   for (int i = 0; i < 5; ++i) {
2697     expected_loops *= loopExtents[i];
2698     ASSERT_EQ(extra2[i], expected_loops);
2699   }
2700 
2701   for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) {
2702     ASSERT_EQ(res2[i], res1[i]);
2703   }
2704 }
2705 
// Runs LoopNestReorderTestHelper for every ordered pair of distinct loop
// indices in [0, 5), passing (true, false) as the helper's two flags.
// NOTE(review): by the test name these flags presumably select the statements
// prepended to each loop body; the helper's signature is outside this excerpt.
TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) {
  for (int first = 0; first < 5; ++first) {
    for (int second = 0; second < 5; ++second) {
      // Reordering a loop with itself is a no-op, and the helper asserts that
      // the IR differs after reordering, so equal indices are skipped.
      if (first == second) {
        continue;
      }
      LoopNestReorderTestHelper(true, false, first, second);
    }
  }
}
2716 
// Runs LoopNestReorderTestHelper for every ordered pair of distinct loop
// indices in [0, 5), passing (false, true) as the helper's two flags.
// NOTE(review): by the test name these flags presumably select the statements
// appended to each loop body; the helper's signature is outside this excerpt.
TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) {
  for (int first = 0; first < 5; ++first) {
    for (int second = 0; second < 5; ++second) {
      // Reordering a loop with itself is a no-op, and the helper asserts that
      // the IR differs after reordering, so equal indices are skipped.
      if (first == second) {
        continue;
      }
      LoopNestReorderTestHelper(false, true, first, second);
    }
  }
}
2727 
// Runs LoopNestReorderTestHelper for every ordered pair of distinct loop
// indices in [0, 5) with both helper flags enabled.
TEST(LoopNest, LoopNestReorderLongStringFull) {
  for (int first = 0; first < 5; ++first) {
    for (int second = 0; second < 5; ++second) {
      // Reordering a loop with itself is a no-op, and the helper asserts that
      // the IR differs after reordering, so equal indices are skipped.
      if (first == second) {
        continue;
      }
      LoopNestReorderTestHelper(true, true, first, second);
    }
  }
}
2738 
TEST(LoopNest,LoopNestReorderInternalLoopNest)2739 TEST(LoopNest, LoopNestReorderInternalLoopNest) {
2740   const int M = 4;
2741   const int N = 5;
2742   const int K = 6;
2743   BufHandle a_buf("a", {M, N}, kFloat);
2744   BufHandle b_buf("b", {N, K}, kFloat);
2745   BufHandle c_buf("c", {M, N}, kFloat);
2746   BufHandle d_buf("d", {M, K}, kFloat);
2747 
2748   Tensor x = Compute(
2749       "x",
2750       {M, N, K},
2751       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2752         return a_buf.load(m, n) * b_buf.load(n, k);
2753       });
2754   Tensor y = Compute(
2755       "y",
2756       {M, N, K},
2757       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2758         return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
2759       });
2760   Tensor z = Compute(
2761       "z",
2762       {M, N, K},
2763       [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2764         return x.load(m, n, k) + y.load(m, n, k);
2765       });
2766 
2767   LoopNest l({z}, {x, y, z});
2768   ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2];
2769   ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0];
2770   LoopNest::reorderAxis(a, b);
2771 
2772   l.prepareForCodegen();
2773   StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
2774 
2775   // Check the IR we produced has the 3 nests in the right order, but k and m
2776   // swapped in the middle.
2777   checkIR(stmt, R"IR(
2778 # CHECK: < 4
2779 # CHECK: < 5
2780 # CHECK: < 6
2781 # CHECK: < 6
2782 # CHECK: < 5
2783 # CHECK: < 4
2784 # CHECK: < 4
2785 # CHECK: < 5
2786 # CHECK: < 6)IR");
2787 
2788   {
2789     PaddedBuffer<float> a_v(M, N);
2790     PaddedBuffer<float> b_v(N, K);
2791     PaddedBuffer<float> c_v(M, N);
2792     PaddedBuffer<float> d_v(M, K);
2793 
2794     for (int i = 0; i < M; i++) {
2795       for (int j = 0; j < N; j++) {
2796         a_v(i, j) = i * i;
2797       }
2798     }
2799     for (int i = 0; i < N; i++) {
2800       for (int j = 0; j < K; j++) {
2801         b_v(i, j) = j * j;
2802       }
2803     }
2804     for (int i = 0; i < M; i++) {
2805       for (int j = 0; j < N; j++) {
2806         c_v(i, j) = i + j;
2807       }
2808     }
2809     for (int i = 0; i < M; i++) {
2810       for (int j = 0; j < K; j++) {
2811         d_v(i, j) = i * j;
2812       }
2813     }
2814 
2815     PaddedBuffer<float> z_v(M, N, K);
2816     PaddedBuffer<float> z_ref(M, N, K);
2817     for (int m = 0; m < M; m++) {
2818       for (int n = 0; n < N; n++) {
2819         for (int k = 0; k < K; k++) {
2820           z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
2821         }
2822       }
2823     }
2824 
2825     SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z});
2826     eval(a_v, b_v, c_v, d_v, z_v);
2827     ExpectAllNear(z_v, z_ref, 1e-5);
2828   }
2829 }
2830 
// Vectorizes the outermost loop of an 8 x 8 computation and checks that only
// one loop level is left in the resulting statement.
TEST(LoopNest, OuterLoopVectorization) {
  Tensor tensor =
      Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) {
        return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
      });
  LoopNest l({tensor});

  ForPtr outermost = l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0];
  ASSERT_TRUE(LoopNest::vectorize(outermost));

  // Descend through any directly nested blocks to reach the block that holds
  // the remaining loop.
  BlockPtr innermost_block = to<Block>(l.root_stmt());
  ASSERT_NE(innermost_block, nullptr);
  for (BlockPtr nested = to<Block>(innermost_block->front()); nested;
       nested = to<Block>(innermost_block->front())) {
    innermost_block = nested;
  }

  // Exactly one loop remains, and its single body statement is not a loop.
  ASSERT_EQ(innermost_block->nstmts(), 1);
  ForPtr remaining_loop = to<For>(innermost_block->front());
  ASSERT_NE(remaining_loop, nullptr);
  BlockPtr loop_body = remaining_loop->body();
  ASSERT_EQ(loop_body->nstmts(), 1);
  ASSERT_EQ(to<For>(loop_body->front()), nullptr);
}
2857 
// Vectorizes an inner loop whose start is 1 rather than 0 and checks that the
// outer loop's body no longer contains a loop afterwards.
TEST(LoopNest, VectorizeLoopNotNormalized) {
  // Input IR:
  //   for (int i = 0; i < 10; i++) {
  //     for (int j = 1; j < 5; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  BufHandle a_buf("A", {10, 5}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store = Store::make(a_buf, {i, j}, i * j);
  auto store_block = Block::make({store});
  auto j_loop = For::make(j, 1, 5, store_block);
  auto i_loop = For::make(i, 0, 10, j_loop);
  auto root = Block::make({i_loop});
  LoopNest l(root, {a_buf.node()});

  ASSERT_TRUE(LoopNest::vectorize(j_loop));

  BlockPtr outer_body = i_loop->body();
  ASSERT_EQ(outer_body->nstmts(), 1);
  ASSERT_EQ(to<For>(outer_body->front()), nullptr);
}
2878 
2879 namespace {
2880 
constantUpperBoundLoopIR(int upper_bound_val)2881 std::string constantUpperBoundLoopIR(int upper_bound_val) {
2882   ExprHandle upper_bound(upper_bound_val);
2883   Tensor A =
2884       Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; });
2885   LoopNest l({A});
2886   std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
2887   StmtPtr unrolled = nullptr;
2888   LoopNest::fullUnroll(loops[0], &unrolled);
2889   std::ostringstream oss;
2890   oss << *unrolled;
2891   return oss.str();
2892 }
2893 
2894 } // namespace
2895 
// Fully unrolling a loop of three iterations yields three explicit stores.
TEST(LoopNest, Unroll) {
  const std::string expected_ir =
      R"IR(
# CHECK: A[0] = 0;
# CHECK: A[1] = 2;
# CHECK: A[2] = 4)IR";
  const std::string unrolled_ir = constantUpperBoundLoopIR(3);

  torch::jit::testing::FileCheck().run(expected_ir, unrolled_ir);
}
2906 
// Fully unrolls the outer loop of a 3 x 4 nest: the result is three copies of
// the inner loop, one per outer index.
TEST(LoopNest, UnrollOuter) {
  ExprHandle outer_bound(3);
  ExprHandle inner_bound(4);
  Tensor A = Compute(
      "A",
      {outer_bound, inner_bound},
      [&](const VarHandle& x, const VarHandle& y) { return x + y; });
  LoopNest l({A});
  ForPtr outer_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][0];

  StmtPtr unrolled = nullptr;
  LoopNest::fullUnroll(outer_loop, &unrolled);

  const std::string expected_ir = R"IR(
# CHECK: for (int i = 0; i < 4; i++) {
# CHECK: A[0, i] = i;
# CHECK: }
# CHECK: for (int i = 0; i < 4; i++) {
# CHECK: A[1, i] = i + 1;
# CHECK: }
# CHECK: for (int i = 0; i < 4; i++) {
# CHECK: A[2, i] = i + 2;
# CHECK: })IR";
  checkIR(unrolled, expected_ir);
}
2929 
// Fully unrolls the inner loop of a 3 x 4 nest: the outer loop is kept and its
// body becomes four explicit stores.
TEST(LoopNest, UnrollInner) {
  ExprHandle outer_bound(3);
  ExprHandle inner_bound(4);
  Tensor A = Compute(
      "A",
      {outer_bound, inner_bound},
      [&](const VarHandle& x, const VarHandle& y) { return x + y; });
  LoopNest l({A});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];

  // The inner loop is the first statement in the outer loop's body.
  ForPtr inner_loop = static_to<For>(loops[0]->body()->stmts().front());
  StmtPtr unrolled = nullptr;
  LoopNest::fullUnroll(inner_loop, &unrolled);

  // The check is made on the outer loop, which now holds the unrolled body.
  const std::string expected_ir = R"IR(
# CHECK: for (int i = 0; i < 3; i++) {
# CHECK: A[i, 0] = i;
# CHECK: A[i, 1] = i + 1;
# CHECK: A[i, 2] = i + 2;
# CHECK: A[i, 3] = i + 3;
# CHECK: })IR";
  checkIR(loops[0], expected_ir);
}
2950 
// Fully unrolls a loop whose body has two statements; the statement pair is
// repeated in order for each iteration.
TEST(LoopNest, UnrollMultipleStatements) {
  const int kTotalSize = 3;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);

  VarHandle x("x", kInt);
  auto write_a = Store::make(a_buf, {x}, x * 2);
  auto copy_a_to_b = Store::make(b_buf, {x}, Load::make(a_buf, {x}));
  auto loop =
      For::make(x, 0, kTotalSize, Block::make({write_a, copy_a_to_b}));
  auto parent_block = Block::make({loop});

  StmtPtr unrolled = nullptr;
  LoopNest::fullUnroll(loop, &unrolled);

  const std::string expected_ir = R"IR(
# CHECK: A[0] = 0;
# CHECK: B[0] = A[0];
# CHECK: A[1] = 2;
# CHECK: B[1] = A[1];
# CHECK: A[2] = 4
# CHECK: B[2] = A[2];)IR";
  checkIR(unrolled, expected_ir);
}
2975 
// Fully unrolls an outer loop whose bounds are constant expressions (2 - 1 and
// 12 / 3) rather than literals.
TEST(LoopNest, UnrollNonLiteralConstantBounds) {
  // Input IR:
  //   for (int i = 2 - 1; i < 12 / 3; i++) {
  //     for (int j = 0; j < 4; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  // NOTE(review): the loop covers i = 1..3 while A is declared as 3 x 4; only
  // the printed IR is checked here, the statement is never evaluated.
  BufHandle a_buf("A", {3, 4}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store_block = Block::make({Store::make(a_buf, {i, j}, i * j)});
  auto inner_for = For::make(j, 0, 4, store_block);
  auto lower = IntImm::make(2) - IntImm::make(1);
  auto upper = IntImm::make(12) / IntImm::make(3);
  auto outer_for = For::make(i, lower, upper, inner_for);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto parent_block = Block::make({outer_for});

  StmtPtr unrolled = nullptr;
  LoopNest::fullUnroll(outer_for, &unrolled);

  const std::string expected_ir = R"IR(
# CHECK: for (int j = 0; j < 4; j++) {
# CHECK:   A[1, j] = j;
# CHECK: }
# CHECK: for (int j = 0; j < 4; j++) {
# CHECK:   A[2, j] = 2 * j;
# CHECK: }
# CHECK: for (int j = 0; j < 4; j++) {
# CHECK:   A[3, j] = 3 * j;
# CHECK: })IR";
  checkIR(unrolled, expected_ir);
}
3010 
// Unrolls an inner loop with a symbolic bound N by a factor of 8: the result
// is an unrolled main loop over N / 8 plus a tail loop over N % 8.
TEST(LoopNest, UnrollNonConstantBounds) {
  // Input IR:
  //   for (int i = 0; i < M; i++) {
  //     for (int j = 0; j < N; j++) {
  //       A[i, j] = i * j;
  //     }
  //   }
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  BufHandle a_buf("A", {M, N}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store_block = Block::make({Store::make(a_buf, {i, j}, i * j)});
  auto j_loop = For::make(j, 0, N, store_block);
  auto i_loop = For::make(i, 0, M, j_loop);
  auto root = Block::make({i_loop});
  LoopNest l(root, {a_buf.node()});

  LoopNest::unroll(j_loop, 8);
  l.simplify();

  const std::string expected_ir = R"IR(
    # CHECK: for (int i = 0; i < M; i++) {
    # CHECK:   for (int j_outer = 0; j_outer < N / 8; j_outer++) {
    # CHECK:     A[i, 8 * j_outer] =
    # CHECK:     A[i, 8 * j_outer + 1] =
    # CHECK:     A[i, 2 * (4 * j_outer + 1)] =
    # CHECK:     A[i, 8 * j_outer + 3] =
    # CHECK:     A[i, 4 * (2 * j_outer + 1)] =
    # CHECK:     A[i, 8 * j_outer + 5] =
    # CHECK:     A[i, 8 * j_outer + 6] =
    # CHECK:     A[i, 8 * j_outer + 7] =
    # CHECK:   }
    # CHECK:   for (int j_tail = 0; j_tail < N % 8; j_tail++) {
    # CHECK:     A[i, 8 * (N / 8) + j_tail] =
    # CHECK:   }
    # CHECK: }
  )IR";
  checkIR(l.root_stmt(), expected_ir);
}
3049 
// Unrolling by a factor below 2 (one, zero or negative) must leave the loop
// nest exactly as it was.
TEST(LoopNest, UnrollByFactorsLessThan2) {
  // Input IR:
  //   for (int i = 0; i < M; i++) {
  //     for (int j = 0; j < N; j++) {
  //       A[i, j] = i * j;
  //     }
  //   }
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  BufHandle a_buf("A", {M, N}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store_block = Block::make({Store::make(a_buf, {i, j}, i * j)});
  auto j_loop = For::make(j, 0, N, store_block);
  auto i_loop = For::make(i, 0, M, j_loop);
  auto root = Block::make({i_loop});
  LoopNest l(root, {a_buf.node()});

  // The same, untouched nest is expected after each attempt.
  const std::string unchanged_ir = R"IR(
    # CHECK: for (int i = 0; i < M; i++) {
    # CHECK:   for (int j = 0; j < N; j++) {
    # CHECK:     A[i, j] =
    # CHECK:   }
    # CHECK: }
  )IR";

  // Factors tried in order: 1, then 0, then a negative value.
  for (int factor : {1, 0, -2}) {
    LoopNest::unroll(j_loop, factor);
    checkIR(l.root_stmt(), unchanged_ir);
  }
}
3098 
// Unrolls a five-iteration loop by a factor of five; the unsimplified result
// keeps an outer loop whose body holds all five stores.
TEST(LoopNest, UnrollByFactorEqualToIters) {
  // Input IR:
  //   for (int i = 0; i < 5; i++) {
  //     A[i] = i * i;
  //   }
  BufHandle a_buf("A", {5}, kInt);
  VarHandle i("i", kInt);
  auto store_block = Block::make({Store::make(a_buf, {i}, i * i)});
  auto loop = For::make(i, 0, 5, store_block);
  auto root = Block::make({loop});
  LoopNest l(root, {a_buf.node()});

  LoopNest::unroll(loop, 5);

  const std::string expected_ir = R"IR(
    # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++)
    # CHECK:   A[5 * i_outer]
    # CHECK:   A[5 * i_outer + 1]
    # CHECK:   A[5 * i_outer + 2]
    # CHECK:   A[5 * i_outer + 3]
    # CHECK:   A[5 * i_outer + 4]
  )IR";
  checkIR(l.root_stmt(), expected_ir);
}
3121 
// Fully unrolling a loop with zero iterations must not produce any store to A.
TEST(LoopNest, UnrollEmpty) {
  const std::string expected_ir = R"IR(
# CHECK-NOT: A[
  )IR";
  const std::string unrolled_ir = constantUpperBoundLoopIR(0);

  torch::jit::testing::FileCheck().run(expected_ir, unrolled_ir);
}
3130 
// Full unrolling of a loop with a symbolic bound must throw an error that
// mentions "non-constant loop".
TEST(LoopNest, NoUnroll) {
  VarHandle upper_bound("N", kInt);
  Tensor A =
      Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; });
  LoopNest l({A});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];

  StmtPtr unrolled = nullptr;
  ASSERT_THROWS_WITH(
      LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop");
}
3141 
// Fully unrolls a loop whose body starts with a Let binding, then checks both
// the printed IR and the values produced when the result is evaluated.
TEST(LoopNest, UnrollWithLet) {
  const int kTotalSize = 3;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);

  VarHandle e("e", kInt);
  VarHandle x("x", kInt);
  auto bind_e = Let::make(e, 7);
  auto write_a = Store::make(a_buf, {x}, e);
  auto write_b = Store::make(b_buf, {x}, e + 1);
  auto loop =
      For::make(x, 0, kTotalSize, Block::make({bind_e, write_a, write_b}));
  auto parent_block = Block::make({loop});

  StmtPtr unrolled = nullptr;
  LoopNest::fullUnroll(loop, &unrolled);

  const std::string expected_ir =
      R"IR(
# CHECK: int e = 7;
# CHECK: A[0] = e;
# CHECK: B[0] = e + 1;
# CHECK: A[1] = e;
# CHECK: B[1] = e + 1;
# CHECK: A[2] = e;
# CHECK: B[2] = e + 1;)IR";
  checkIR(unrolled, expected_ir);

  // Evaluating the unrolled statement writes e to A and e + 1 to B.
  std::vector<int> a_v(kTotalSize, 0);
  std::vector<int> b_v(kTotalSize, 0);
  SimpleIREvaluator eval(unrolled, {a_buf, b_buf});
  eval(a_v, b_v);
  for (int idx = 0; idx < kTotalSize; ++idx) {
    ASSERT_EQ(a_v[idx], 7);
    ASSERT_EQ(b_v[idx], 8);
  }
}
3183 
// LoopNest::isNormalized is true only when the loop's start is the constant 0.
TEST(LoopNest, IsNormalized) {
  // Input IR:
  //   for (int i = 50; i < 100; i++) {
  //     A[i] = B[i];
  //   }
  BufHandle a_buf("A", {ExprHandle(100)}, kInt);
  BufHandle b_buf("B", {ExprHandle(100)}, kInt);
  VarHandle i("i", kInt);
  auto copy_stmt = Store::make(a_buf, {i}, Load::make(b_buf, {i}));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto loop = For::make(i, 50, 100, copy_stmt);
  Block::make({loop});

  // Start of 50: not normalized.
  ASSERT_FALSE(LoopNest::isNormalized(loop));

  // Start of 0: normalized.
  loop->set_start(alloc<IntImm>(0));
  ASSERT_TRUE(LoopNest::isNormalized(loop));

  // Symbolic start: not normalized.
  VarHandle N("N", kInt);
  loop->set_start(N.node());
  ASSERT_FALSE(LoopNest::isNormalized(loop));
}
3205 
// Normalizing a loop with a positive constant start shifts the range to begin
// at 0 and adds the old start to every use of the loop variable.
TEST(LoopNest, NormalizeStartPositive) {
  // Input IR:
  //   for (int x = 50; x < 100; x++) {
  //     A[x] = B[x];
  //     B[x] = x * 2;
  //   }
  const int kTotalSize = 50;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
  VarHandle x("x", kInt);
  auto copy_b_to_a = Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}));
  auto write_b = Store::make(b_buf, {x}, x * 2);
  auto loop_body = Block::make({copy_b_to_a, write_b});
  auto loop = For::make(x, 50, 100, loop_body);
  Block::make({loop});

  LoopNest::normalize(loop);

  auto simplified = IRSimplifier::simplify(loop);
  checkIR(
      simplified,
      R"IR(
        # CHECK: for (int x = 0; x < 50; x++) {
        # CHECK:   A[x + 50] = B[x + 50];
        # CHECK:   B[x + 50] = 2 * (x + 50);
      )IR");
}
3235 
// Normalizing a loop with a negative constant start shifts the range to begin
// at 0; the body's "+ 50" offsets cancel against the substituted start.
TEST(LoopNest, NormalizeStartNegative) {
  // Input IR:
  //   for (int x = -50; x < 100; x++) {
  //     A[x + 50] = B[x + 50];
  //     B[x + 50] = x * 2;
  //   }
  const int kTotalSize = 150;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
  VarHandle x("x", kInt);
  auto copy_b_to_a =
      Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50}));
  auto write_b = Store::make(b_buf, {x + 50}, x * 2);
  auto loop_body = Block::make({copy_b_to_a, write_b});
  auto loop = For::make(x, -50, 100, loop_body);
  Block::make({loop});

  LoopNest::normalize(loop);

  auto simplified = IRSimplifier::simplify(loop);
  checkIR(
      simplified,
      R"IR(
        # CHECK: for (int x = 0; x < 150; x++) {
        # CHECK:   A[x] = B[x];
        # CHECK:   B[x] = 2 * (x - 50);
      )IR");
}
3265 
// Normalizing a loop that already starts at 0 leaves it unchanged.
TEST(LoopNest, NormalizeStartZero) {
  // Input IR:
  //   for (int x = 0; x < 100; x++) {
  //     A[x] = B[x];
  //     B[x] = x * 2;
  //   }
  // Should not be modified.

  const int kTotalSize = 100;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
  VarHandle x("x", kInt);
  auto copy_b_to_a = Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}));
  auto write_b = Store::make(b_buf, {x}, x * 2);
  auto loop_body = Block::make({copy_b_to_a, write_b});
  auto loop = For::make(x, 0, 100, loop_body);
  Block::make({loop});

  LoopNest::normalize(loop);

  auto simplified = IRSimplifier::simplify(loop);
  checkIR(
      simplified,
      R"IR(
        # CHECK: for (int x = 0; x < 100; x++) {
        # CHECK:   A[x] = B[x];
        # CHECK:   B[x] = 2 * x;
      )IR");
}
3297 
// Normalizing a loop whose start is a variable y moves y into the stop
// expression and into every use of the loop variable.
TEST(LoopNest, NormalizeStartVariable) {
  // Input IR:
  //   for (int x = y; x < 100; x++) {
  //     A[x] = B[x];
  //     B[x] = x * 2;
  //   }

  const int kTotalSize = 100;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  auto copy_b_to_a = Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}));
  auto write_b = Store::make(b_buf, {x}, x * 2);
  auto loop_body = Block::make({copy_b_to_a, write_b});
  auto loop = For::make(x, y, 100, loop_body);
  auto parent_block = Block::make({loop});

  LoopNest::normalize(loop);

  auto simplified = IRSimplifier::simplify(loop);
  checkIR(
      simplified,
      R"IR(
        # CHECK: for (int x = 0; x < 100 - y; x++) {
        # CHECK:   A[x + y] = B[x + y];
        # CHECK:   B[x + y] = 2 * (x + y);
      )IR");
}
3329 
// Normalizing only the outer loop of a nest rewrites uses of x and leaves the
// inner loop's range as it was.
TEST(LoopNest, NormalizeOnNestedOuterLoop) {
  // Input IR:
  //   for (int x = 50; x < 100; x++) {
  //     for (int y = 10; y < 100; y++) {
  //       A[x] = A[x] + B[y] + y * 2;
  //     }
  //   }

  BufHandle a_buf("A", {ExprHandle(50)}, kInt);
  BufHandle b_buf("B", {ExprHandle(100)}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  auto accumulate = Store::make(
      a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2);
  auto y_loop = For::make(y, 10, 100, accumulate);
  auto x_loop = For::make(x, 50, 100, y_loop);
  Block::make({x_loop});

  LoopNest::normalize(x_loop);

  auto simplified = IRSimplifier::simplify(x_loop);
  checkIR(
      simplified,
      R"IR(
        # CHECK: for (int x = 0; x < 50; x++) {
        # CHECK:   for (int y = 10; y < 100; y++) {
        # CHECK:     A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y;
      )IR");
}
3361 
// Normalizing only the inner loop of a nest rewrites uses of y and leaves the
// outer loop's range as it was.
TEST(LoopNest, NormalizeOnNestedInnerLoop) {
  // Input IR:
  //   for (int x = 50; x < 100; x++) {
  //     for (int y = 10; y < 100; y++) {
  //       A[x] = A[x] + B[y] + y * 2;
  //     }
  //   }

  BufHandle a_buf("A", {ExprHandle(50)}, kInt);
  BufHandle b_buf("B", {ExprHandle(100)}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  auto accumulate = Store::make(
      a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2);
  auto y_loop = For::make(y, 10, 100, accumulate);
  auto x_loop = For::make(x, 50, 100, y_loop);
  Block::make({x_loop});

  LoopNest::normalize(y_loop);

  // The whole nest is printed, starting from the outer loop.
  auto simplified = IRSimplifier::simplify(x_loop);
  checkIR(
      simplified,
      R"IR(
        # CHECK: for (int x = 50; x < 100; x++) {
        # CHECK:   for (int y = 0; y < 90; y++) {
        # CHECK:     A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20;
      )IR");
}
3393 
// Normalizes a five-iteration loop and then splits it by a factor of 10, which
// is larger than the trip count; checks the simplified outer and tail loops.
TEST(LoopNest, NormalizeAndSplitWithTail) {
  // Create a dummy tensor to construct LoopNest.
  ExprHandle n(100);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});

  // Input IR:
  //   for (int x = 5; x < 10; x++) {
  //     A[x] = x * 2;
  //   }
  const int kTotalSize = 5;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  VarHandle x("x", kInt);
  auto write_a = Store::make(a_buf, {x}, x * 2);
  auto loop = For::make(x, 5, 10, write_a);
  auto parent_block = Block::make({loop});

  LoopNest::normalize(loop);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr x_inner;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr x_tail;
  LoopNest::splitWithTail(loop, 10, &x_inner, &x_tail);

  // Outer part: only the presence of a braced statement is checked.
  auto simplified_outer = IRSimplifier::simplify(loop);
  checkIR(
      simplified_outer,
      R"IR(
        # CHECK: {
        # CHECK: }
      )IR");

  // Tail part: carries all five iterations.
  auto simplified_tail = IRSimplifier::simplify(x_tail);
  checkIR(
      simplified_tail,
      R"IR(
        # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) {
        # CHECK:   A[x_tail + 5] = 2 * (x_tail + 5);
      )IR");
}
3439 
// Splits a ten-iteration loop starting at 5 by a factor of 8 without
// normalizing it first; checks the simplified outer and tail loops.
TEST(LoopNest, NotNormalizeAndSplitWithTail) {
  // Create a dummy tensor to construct LoopNest.
  ExprHandle n(100);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});

  // Input IR:
  //   for (int x = 5; x < 15; x++) {
  //     A[x] = x * 2;
  //   }
  const int kTotalSize = 10;
  BufHandle a_buf("A", {kTotalSize}, kInt);
  VarHandle x("x", kInt);
  auto write_a = Store::make(a_buf, {x}, x * 2);
  auto loop = For::make(x, 5, 15, write_a);
  auto parent_block = Block::make({loop});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr x_inner;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr x_tail;
  LoopNest::splitWithTail(loop, 8, &x_inner, &x_tail);

  // Outer part: only the presence of a braced statement is checked.
  auto simplified_outer = IRSimplifier::simplify(loop);
  checkIR(
      simplified_outer,
      R"IR(
        # CHECK: {
        # CHECK: }
      )IR");

  // Tail part: the two iterations left over after the first eight.
  auto simplified_tail = IRSimplifier::simplify(x_tail);
  checkIR(
      simplified_tail,
      R"IR(
        # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) {
        # CHECK:   A[x_tail + 13] = 2 * (x_tail + 13);
      )IR");
}
3483 
// Flattens a 10 x 5 loop nest into a single loop of 50 iterations, checks the
// printed IR, and checks that the flattened loop computes the same values as
// the original nest.
TEST(LoopNest, FlattenSimpleLoopNest2D) {
  // Input IR:
  //   for (int i = 0; i < 10; i++) {
  //     for (int j = 0; j < 5; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  BufHandle a_buf("A", {10, 5}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
  auto inner_for = For::make(j, 0, 5, for_body);
  auto outer_for = For::make(i, 0, 10, inner_for);
  auto parent_block = Block::make({outer_for});

  // flatten() hands back the outermost loop itself (asserted below), so the
  // un-flattened nest is copied up front to serve as the evaluation
  // reference. Without the copy both evaluators would run the same statement.
  StmtPtr reference = Stmt::clone(outer_for);

  std::vector<ForPtr> loops = {outer_for, inner_for};
  ForPtr flattened = nullptr;
  ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
  ASSERT_EQ(flattened, loops.front());

  auto result = IRSimplifier::simplify(flattened);
  checkIR(
      result,
      R"IR(
        # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
        # CHECK:   A[i_flat / 5, i_flat % 5] =
      )IR");

  {
    // Original nest versus flattened loop.
    SimpleIREvaluator eval1(reference, {a_buf});
    PaddedBuffer<int> inp1(10, 5);
    eval1(inp1);
    SimpleIREvaluator eval2(flattened, {a_buf});
    PaddedBuffer<int> inp2(10, 5);
    eval2(inp2);
    ExpectAllNear(inp1, inp2, 1e-5);
  }
}
3524 
// Flattens a 10 x 5 x 7 loop nest into a single loop of 350 iterations, checks
// the printed IR, and checks that the flattened loop computes the same values
// as the original nest.
TEST(LoopNest, FlattenSimpleLoopNest3D) {
  // Input IR:
  //   for (int i = 0; i < 10; i++) {
  //     for (int j = 0; j < 5; j++) {
  //       for (int k = 0; k < 7; k++) {
  //         A[i,j,k] = i + j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {10, 5, 7}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)});
  auto for1 = For::make(k, 0, 7, for_body);
  auto for2 = For::make(j, 0, 5, for1);
  auto for3 = For::make(i, 0, 10, for2);
  auto parent_block = Block::make({for3});

  // flatten() hands back the outermost loop itself (asserted below), so the
  // un-flattened nest is copied up front to serve as the evaluation
  // reference. Without the copy both evaluators would run the same statement.
  StmtPtr reference = Stmt::clone(for3);

  std::vector<ForPtr> loops = {for3, for2, for1};
  ForPtr flattened = nullptr;
  ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
  ASSERT_EQ(flattened, loops.front());

  auto result = IRSimplifier::simplify(flattened);
  checkIR(
      result,
      R"IR(
        # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) {
        # CHECK:   A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] =
      )IR");

  {
    // Original nest versus flattened loop.
    SimpleIREvaluator eval1(reference, {a_buf});
    PaddedBuffer<int> inp1(10, 5, 7);
    eval1(inp1);
    SimpleIREvaluator eval2(flattened, {a_buf});
    PaddedBuffer<int> inp2(10, 5, 7);
    eval2(inp2);
    ExpectAllNear(inp1, inp2, 1e-5);
  }
}
3569 
// Flattens a nest whose loops start at non-zero values into a single loop of
// 96 iterations, checks the printed IR, and checks that the flattened loop
// computes the same values as the original nest.
TEST(LoopNest, FlattenLoopNestAfterNormalize) {
  // Input IR:
  //   for (int i = 2; i < 10; i++) {
  //     for (int j = 3; j < 15; j++) {
  //       A[i - 2,j - 3] = i * j;
  //     }
  //   }
  BufHandle a_buf("A", {8, 12}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)});
  auto inner_for = For::make(j, 3, 15, for_body);
  auto outer_for = For::make(i, 2, 10, inner_for);
  auto parent_block = Block::make({outer_for});

  // flatten() hands back the outermost loop itself (asserted below), so the
  // un-flattened nest is copied up front to serve as the evaluation
  // reference. Without the copy both evaluators would run the same statement.
  StmtPtr reference = Stmt::clone(outer_for);

  std::vector<ForPtr> loops = {outer_for, inner_for};
  ForPtr flattened = nullptr;
  ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
  ASSERT_EQ(flattened, loops.front());

  auto result = IRSimplifier::simplify(flattened);
  checkIR(
      result,
      R"IR(
        # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) {
        # CHECK:   A[i_flat / 12, i_flat % 12] =
      )IR");

  {
    // Original nest versus flattened loop.
    SimpleIREvaluator eval1(reference, {a_buf});
    PaddedBuffer<int> inp1(8, 12);
    eval1(inp1);
    SimpleIREvaluator eval2(flattened, {a_buf});
    PaddedBuffer<int> inp2(8, 12);
    eval2(inp2);
    ExpectAllNear(inp1, inp2, 1e-5);
  }
}
3610 
TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) {
  // Flattens a perfect nest whose stop bounds are constant expressions
  // (15 - 5 and 20 / 4) rather than plain literals.
  //
  // Input IR:
  //   for (int i = 0; i < 15-5; i++) {
  //     for (int j = 0; j < 20/4; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  BufHandle a_buf("A", {10, 5}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
  auto inner_for =
      For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body);
  auto outer_for =
      For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for);
  auto b = Block::make({outer_for});

  // After a successful flatten the returned loop is the very node that was
  // passed in as loops.front() (asserted below), so loops[0] no longer
  // describes the original nest. Copy the nest up front so the numerical
  // comparison at the end has an independent reference to run against.
  StmtPtr reference = Stmt::clone(b);

  std::vector<ForPtr> loops = {outer_for, inner_for};
  ForPtr flattened = nullptr;
  ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
  ASSERT_EQ(flattened, loops.front());

  // The simplified flattened loop must iterate 10 * 5 = 50 times.
  auto result = IRSimplifier::simplify(flattened);
  checkIR(result, R"IR(
        # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
        # CHECK:   A[i_flat / 5, i_flat % 5] =
      )IR");

  {
    // Evaluate the untouched copy and the flattened loop into separate
    // buffers; both must produce the same contents.
    SimpleIREvaluator eval1(reference, {a_buf});
    PaddedBuffer<int> inp1(10, 5);
    eval1(inp1);
    SimpleIREvaluator eval2(flattened, {a_buf});
    PaddedBuffer<int> inp2(10, 5);
    eval2(inp2);
    ExpectAllNear(inp1, inp2, 1e-5);
  }
}
3650 
TEST(LoopNest, FlattenImperfectLoopNest) {
  // The outer loop body holds a store next to the inner loop, so the nest is
  // not perfect. flatten() is expected to report failure, leave its output
  // pointer null and leave the IR untouched.
  //
  // Input IR:
  //   for (int i = 0; i < 10; i++) {
  //     A[i, i] = 0;
  //     for (int j = 0; j < 15; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }

  BufHandle a_buf("A", {10, 15}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto product_store = Store::make(a_buf, {i, j}, i * j);
  auto j_loop = For::make(j, 0, 15, Block::make({product_store}));
  auto diagonal_store = Store::make(a_buf, {i, i}, 0);
  auto i_loop = For::make(i, 0, 10, Block::make({diagonal_store, j_loop}));
  auto root = Block::make({i_loop});

  // Hash the whole tree before the attempt so any mutation is detectable.
  HashProvider hasher;
  auto hash_before = hasher.hash(root);

  std::vector<ForPtr> nest = {i_loop, j_loop};
  ForPtr flattened = nullptr;
  const bool did_flatten = LoopNest::flatten(nest, &flattened);
  ASSERT_FALSE(did_flatten);
  ASSERT_EQ(flattened, nullptr);

  auto hash_after = hasher.hash(root);
  ASSERT_EQ(hash_before, hash_after);
}
3679 
TEST(LoopNest, FlattenReductionLoopNest) {
  // A hand-written reduction: the outer body initialises S[i] and then runs
  // the accumulating inner loop. flatten() is expected to report failure,
  // leave its output pointer null and leave the IR untouched.
  //
  // Input IR:
  //   for (int i = 0; i < 10; i++) {
  //     S[i] = 0;
  //     for (int j = 0; j < 15; j++) {
  //       S[i] = S[i] + A[i,j];
  //     }
  //   }

  BufHandle a_buf("A", {10, 15}, kInt);
  BufHandle s_buf("S", {10}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto accumulate = Store::make(
      s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}));
  auto reduce_loop = For::make(j, 0, 15, Block::make({accumulate}));
  auto init_store = Store::make(s_buf, {i}, 0);
  auto row_loop =
      For::make(i, 0, 10, Block::make({init_store, reduce_loop}));
  auto root = Block::make({row_loop});

  // Hash the whole tree before the attempt so any mutation is detectable.
  HashProvider hasher;
  auto hash_before = hasher.hash(root);

  std::vector<ForPtr> nest = {row_loop, reduce_loop};
  ForPtr flattened = nullptr;
  const bool did_flatten = LoopNest::flatten(nest, &flattened);
  ASSERT_FALSE(did_flatten);
  ASSERT_EQ(flattened, nullptr);

  auto hash_after = hasher.hash(root);
  ASSERT_EQ(hash_before, hash_after);
}
3710 
TEST(LoopNest, FlattenReductionLoopNestFromTensor) {
  // Builds a sum reduction with Reduce() and tries to flatten one of the loop
  // nests that write to its buffer. The attempt must fail, leave the output
  // pointer null and leave the root statement's hash unchanged.
  const int M = 3;
  const int N = 7;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  BufHandle b("b", {m, n}, kFloat);
  Tensor c = Reduce("sum", {M}, Sum(), b, {N});
  LoopNest loop({c});

  HashProvider hasher;
  auto hash_before = hasher.hash(loop.root_stmt());

  // Index 1 selects the second nest writing to c; presumably this is the
  // accumulation nest that follows the initialiser -- TODO confirm.
  auto reduction_nest = loop.getAllLoopNestsWritingToBuf(c.buf())[1];
  ForPtr flattened = nullptr;
  const bool did_flatten = LoopNest::flatten(reduction_nest, &flattened);
  ASSERT_FALSE(did_flatten);
  ASSERT_EQ(flattened, nullptr);

  auto hash_after = hasher.hash(loop.root_stmt());
  ASSERT_EQ(hash_before, hash_after);
}
3729 
TEST(LoopNest, FlattenIncorrectLoopsAsInput) {
  // Two independent nests live in the same block. The loops handed to
  // flatten() are the outer loop of the first nest and the inner loop of the
  // second one, which do not form a nest; the call must fail without
  // modifying the IR.
  //
  // Input IR:
  //   for (int i = 0; i < 10; i++) {
  //     for (int j = 0; j < 5; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  //   for (int x = 0; x < 10; x++) {
  //     for (int y = 0; y < 5; y++) {
  //       A[x,y] = A[x,y] + x + y;
  //     }
  //   }
  // Flatten({For_i, For_y}) => should not succeed

  BufHandle a_buf("A", {10, 5}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // First nest: A[i,j] = i * j.
  auto first_store = Store::make(a_buf, {i, j}, i * j);
  auto first_inner = For::make(j, 0, 5, Block::make({first_store}));
  auto first_outer = For::make(i, 0, 10, first_inner);

  // Second nest: A[x,y] = A[x,y] + x + y.
  auto second_store =
      Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y);
  auto second_inner = For::make(y, 0, 5, Block::make({second_store}));
  auto second_outer = For::make(x, 0, 10, second_inner);

  auto root = Block::make({first_outer, second_outer});
  HashProvider hasher;
  auto hash_before = hasher.hash(root);

  std::vector<ForPtr> mismatched = {first_outer, second_inner};
  ForPtr flattened = nullptr;
  const bool did_flatten = LoopNest::flatten(mismatched, &flattened);
  ASSERT_FALSE(did_flatten);
  ASSERT_EQ(flattened, nullptr);

  auto hash_after = hasher.hash(root);
  ASSERT_EQ(hash_before, hash_after);
}
3767 
TEST(LoopNest, DetectInlineRankMismatch) {
  // `a` is computed over a single dimension but `reshape` loads it with two
  // indices. computeInline() on a's loop body is expected to return false.
  const int kTotalSize = 8;

  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
  Tensor a = Compute(
      "a", {kTotalSize}, [&](const VarHandle& idx) { return a_buf.load(idx); });
  Tensor reshape = Compute(
      "reshape",
      {kTotalSize / 2, 2},
      [&](const VarHandle& row, const VarHandle& col) {
        return a.load(row, col);
      });

  LoopNest l({reshape}, {a, reshape});
  auto producer_body = l.getLoopBodyFor(a);
  ASSERT_FALSE(l.computeInline(producer_body));
}
3781 
TEST(LoopNest, CacheReadsSimple) {
  // Caches the reads of A made inside B's inner (j) loop into "A_local",
  // checks the complete generated IR, then evaluates it and compares B and C
  // against values computed directly on the host.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& r, const VarHandle& c) {
    return r * c;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 30, c + 3);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 10, c + 20) + A.load(r + 30, c + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  auto b_nest = l.getAllLoopNestsWritingToBuf(B.buf())[0];
  StmtPtr inner_loop = b_nest[1];
  LoopNest::cacheAccesses(A.buf(), "A_local", inner_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  // just this once: verify the whole thing.
  std::ostringstream printed;
  printed << *result;
  torch::jit::testing::FileCheck().run(
      R"IR(
#CHECK: Allocate(A); // dtype=int, dims=[64, 64]
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
#CHECK: for (int i
#CHECK:  for (int j
#CHECK:   A[
#CHECK:  }
#CHECK: }
#CHECK: for (int i_1
#CHECK:  for (int j_1
#CHECK:   A_local[j_1] = A[
#CHECK:  }
#CHECK:  for (int j_2
#CHECK:   B[j_2 + 10 * i_1] = A_local[j_2];
#CHECK:  }
#CHECK: }
#CHECK: for (int i_2
#CHECK:  for (int j_3
#CHECK:   C[
#CHECK:  }
#CHECK: }
#CHECK: Free(A_local);
#CHECK: Free(A);
      )IR",
      printed.str());

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference: walk the 20 x 10 outputs in row-major order.
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);
  for (int flat = 0; flat < 200; ++flat) {
    const int row = flat / 10;
    const int col = flat % 10;
    b_ref[flat] = (row + 30) * (col + 3);
    c_ref[flat] = (row + 10) * (col + 20) + (row + 30) * (col + 40);
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
3848 
TEST(LoopNest, CacheReadsOuter) {
  // Caches the reads of A made inside B's outer (i) loop into "A_local".
  // B reads A at two offsets that differ by one in each dimension, and the
  // expected cache is 21 x 11.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& r, const VarHandle& c) {
    return r * c;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 30, c + 40) + A.load(r + 31, c + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 10, c + 20) + A.load(r + 30, c + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  auto b_nest = l.getAllLoopNestsWritingToBuf(B.buf())[0];
  StmtPtr outer_loop = b_nest[0];
  LoopNest::cacheAccesses(A.buf(), "A_local", outer_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  std::ostringstream printed;
  printed << *result;
  torch::jit::testing::FileCheck().run(
      R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
#CHECK: A_local[j_1 + 11 * i_1] =
#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]);
      )IR",
      printed.str());

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference: walk the 20 x 10 outputs in row-major order.
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);
  for (int flat = 0; flat < 200; ++flat) {
    const int row = flat / 10;
    const int col = flat % 10;
    b_ref[flat] = (row + 30) * (col + 40) + (row + 31) * (col + 41);
    c_ref[flat] = (row + 10) * (col + 20) + (row + 30) * (col + 40);
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
3895 
TEST(LoopNest, CacheReadsInternal) {
  // Same tensors as the outer-loop caching case, but the cache is requested
  // on B's inner (j) loop; the expected cache is 2 x 11.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& r, const VarHandle& c) {
    return r * c;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 30, c + 40) + A.load(r + 31, c + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 10, c + 20) + A.load(r + 30, c + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  auto b_nest = l.getAllLoopNestsWritingToBuf(B.buf())[0];
  StmtPtr inner_loop = b_nest[1];
  LoopNest::cacheAccesses(A.buf(), "A_local", inner_loop);
  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  std::ostringstream printed;
  printed << *result;
  torch::jit::testing::FileCheck().run(
      R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
#CHECK: A_local[k + 11 * j_1] =
#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]);
      )IR",
      printed.str());

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference: walk the 20 x 10 outputs in row-major order.
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);
  for (int flat = 0; flat < 200; ++flat) {
    const int row = flat / 10;
    const int col = flat % 10;
    b_ref[flat] = (row + 30) * (col + 40) + (row + 31) * (col + 41);
    c_ref[flat] = (row + 10) * (col + 20) + (row + 30) * (col + 40);
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
3941 
TEST(LoopNest, CacheReadsInner) {
  // Requests the cache on the innermost body of B's loop nest. The expected
  // cache is 5 x 2 and B reads it at two constant offsets.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& r, const VarHandle& c) {
    return r * c;
  });
  // Note: the row offset of B's first load of A is 34 here, not 30 as in the
  // other cache-read tests.
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 34, c + 40) + A.load(r + 30, c + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 10, c + 20) + A.load(r + 30, c + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  StmtPtr b_body = l.getLoopBodyFor(B);
  LoopNest::cacheAccesses(A.buf(), "A_local", b_body);
  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  std::ostringstream printed;
  printed << *result;
  torch::jit::testing::FileCheck().run(
      R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
#CHECK: A_local[l + 2 * k] =
#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]);
      )IR",
      printed.str());

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference: walk the 20 x 10 outputs in row-major order.
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);
  for (int flat = 0; flat < 200; ++flat) {
    const int row = flat / 10;
    const int col = flat % 10;
    b_ref[flat] = (row + 34) * (col + 40) + (row + 30) * (col + 41);
    c_ref[flat] = (row + 10) * (col + 20) + (row + 30) * (col + 40);
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
3988 
TEST(LoopNest, CacheWritesSimple) {
  // Requests the cache on the inner loop of the nest that WRITES A. The
  // expected IR fills A_local first and then copies it into A, and A_local
  // must not appear again after it is freed.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& r, const VarHandle& c) {
    return r * c;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 30, c + 40) + A.load(r + 31, c + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& r, const VarHandle& c) {
        return A.load(r + 10, c + 20) + A.load(r + 30, c + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  auto a_nest = l.getAllLoopNestsWritingToBuf(A.buf())[0];
  StmtPtr a_inner_loop = a_nest[1];
  LoopNest::cacheAccesses(A.buf(), "A_local", a_inner_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  std::ostringstream printed;
  printed << *result;
  torch::jit::testing::FileCheck().run(
      R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64]
#CHECK: for (int j = 0; j < 64
#CHECK:   A_local[j] = i * j;
#CHECK: for (int j_1 = 0; j_1 < 64
#CHECK:   A[j_1 + 64 * i] = A_local[
#CHECK: Free(A_local);
#CHECK-NOT: A_local
      )IR",
      printed.str());

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference: walk the 20 x 10 outputs in row-major order.
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);
  for (int flat = 0; flat < 200; ++flat) {
    const int row = flat / 10;
    const int col = flat % 10;
    b_ref[flat] = (row + 30) * (col + 40) + (row + 31) * (col + 41);
    c_ref[flat] = (row + 10) * (col + 20) + (row + 30) * (col + 40);
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
4039 
TEST(LoopNest, DeadStoreElimination) {
  // One loop nest stores into both f and g. With only f listed as an output
  // the store to g must be removed; with both listed, both stores must stay.
  VarHandle y("y", kInt);
  VarHandle x("x_tail", kInt);
  BufHandle f("f", {26, 5}, kInt);
  BufHandle g("g", {26, 5}, kInt);
  ExprHandle outer_extent = 5;
  ExprHandle row = x + outer_extent * 4;

  auto both_stores = Block::make({
      Store::make(f, {row, y}, (row + y)),
      Store::make(g, {row, y}, (row * y)),
  });
  auto y_loop = For::make(y, 0, 5, both_stores);
  ForPtr x_loop = For::make(x, 0, 5, y_loop);
  StmtPtr stmt = Block::make({x_loop});

  // Will eliminate if not used by an output. A clone is handed over so the
  // second LoopNest below can still be built from `stmt`.
  LoopNest only_f(Stmt::clone(stmt), {f.node()});
  only_f.eliminateDeadStores();

  std::ostringstream only_f_ir;
  only_f_ir << *only_f.root_stmt();
  torch::jit::testing::FileCheck().run(
      R"IR(
#CHECK:     f[x_tail + 5 * 4, y]
#CHECK-NOT: g[x_tail + 5 * 4, y]
      )IR",
      only_f_ir.str());

  // But won't eliminate if used by different outputs.
  LoopNest f_and_g(stmt, {f.node(), g.node()});
  f_and_g.eliminateDeadStores();

  std::ostringstream f_and_g_ir;
  f_and_g_ir << *f_and_g.root_stmt();
  torch::jit::testing::FileCheck().run(
      R"IR(
#CHECK:     f[x_tail + 5 * 4, y]
#CHECK:     g[x_tail + 5 * 4, y]
      )IR",
      f_and_g_ir.str());
}
4079 
TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
  // Three loop nests: one fills f, one fills g, one computes h from f.
  // With only h as an output, the store to g is dead, while the store to f
  // must survive because h reads f.
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);
  BufHandle f("f", {26 * 5}, kInt);
  BufHandle g("g", {26 * 5}, kInt);
  BufHandle h("h", {26, 5}, kInt);
  ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x));
  ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1));
  ForPtr stmt3 = For::make(
      x,
      0,
      5,
      For::make(
          y,
          0,
          5,
          Block::make({
              Store::make(h, {x, y}, Load::make(f, {x * y})),
          })));
  StmtPtr stmt = Block::make({stmt1, stmt2, stmt3});

  // Will eliminate the write to g, but not f since it used by the producer of
  // h. A clone is handed over so the second LoopNest below can still be
  // built from `stmt`.
  LoopNest loop(Stmt::clone(stmt), {h.node()});
  loop.eliminateDeadStores();

  checkIR(loop.root_stmt(), R"IR(
  #CHECK:     f[x] = x;
  #CHECK-NOT: g[z] =
  #CHECK:     h[x, y] = f[x * y];
      )IR");

  // Sanity check won't eliminate if g is an output.
  LoopNest loop2(stmt, {h.node(), g.node()});
  loop2.eliminateDeadStores();

  checkIR(loop2.root_stmt(), R"IR(
  #CHECK:     f[x] = x;
  #CHECK:     g[z] = z + 1;
  #CHECK:     h[x, y] = f[x * y];
      )IR");
}
4125 
TEST(LoopNest, CompoundTensorSimple) {
  // Wraps a block holding two loop nests over the same buffer in a single
  // Tensor, evaluates it, and checks the combined effect of both nests:
  // the first writes i * j, the second adds x + y on top.
  BufHandle a_buf("A", {10, 5}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // First nest: A[i,j] = i * j.
  auto init_store = Store::make(a_buf, {i, j}, i * j);
  auto init_inner = For::make(j, 0, 5, Block::make({init_store}));
  auto init_outer = For::make(i, 0, 10, init_inner);

  // Second nest: A[x,y] = A[x,y] + x + y.
  auto update_store =
      Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y);
  auto update_inner = For::make(y, 0, 5, Block::make({update_store}));
  auto update_outer = For::make(x, 0, 10, update_inner);

  BlockPtr body = Block::make({init_outer, update_outer});
  Tensor A(a_buf.node(), body);

  LoopNest l({A});
  l.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  SimpleIREvaluator cg(s, {A});

  std::vector<int> a_data(50, 0);
  cg.call({a_data});

  // Host-side reference in row-major order over the 10 x 5 buffer.
  std::vector<int> a_ref(50, 0);
  for (int flat = 0; flat < 50; ++flat) {
    const int row = flat / 5;
    const int col = flat % 5;
    a_ref[flat] = (row * col) + row + col;
  }

  assertAllEqual(a_data, a_ref);
}
4162 
TEST(LoopNest, InlineConstantIndex) {
  // y copies the input buffer and z copies y; both have shape {1, N, 1}.
  // After simplifying the loop nest, inlining y must succeed.
  // NOTE(review): the size-1 dimensions presumably turn into constant indices
  // during simplify(), which is what the test name refers to -- confirm.
  const int N = 10;
  BufHandle x_buf("a", {1, N, 1}, kFloat);

  const auto copy_input =
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
        return x_buf.load(m, n, o);
      };
  Tensor y = Compute("f", {1, N, 1}, copy_input);

  const auto copy_y =
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
        return y.load(m, n, o);
      };
  Tensor z = Compute("f", {1, N, 1}, copy_y);

  LoopNest l({z}, {y, z});
  l.simplify();
  const bool inlined = l.computeInline(y.buf());
  ASSERT_TRUE(inlined);
}
4183 
TEST(LoopNest, CompoundTensorUsed) {
  // A is a Tensor backed by a block of two loop nests over the same buffer
  // (first A[i,j] = i * j, then A[x,y] += x + y). B consumes A. Inlining A
  // is expected to be rejected, and evaluating B must still give the right
  // values.
  BufHandle a_buf("A", {10, 5}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)});
  auto inner_for1 = For::make(j, 0, 5, for_body1);
  auto outer_for1 = For::make(i, 0, 10, inner_for1);
  auto for_body2 = Block::make(
      {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
  auto inner_for2 = For::make(y, 0, 5, for_body2);
  auto outer_for2 = For::make(x, 0, 10, inner_for2);
  BlockPtr body = Block::make({outer_for1, outer_for2});

  Tensor A = Tensor(a_buf.node(), body);
  Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) {
    return A.load(i, j + 1) + A.load(i, j + 2);
  });

  LoopNest l({B}, {A, B});
  ASSERT_FALSE(l.computeInline(A.buf()));
  l.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  SimpleIREvaluator cg(s, {B});

  // B is 10 x 3, so the output and reference buffers hold exactly 30 values.
  constexpr int kRows = 10;
  constexpr int kCols = 3;
  std::vector<int> b_data(kRows * kCols, 0);
  std::vector<int> b_ref(kRows * kCols, 0);

  // Value of A after both loop nests have run.
  auto AT = [](int i, int j) { return i * j + i + j; };
  for (int i = 0; i < kRows; ++i) {
    for (int j = 0; j < kCols; ++j) {
      b_ref[i * kCols + j] = AT(i, j + 1) + AT(i, j + 2);
    }
  }
  cg.call({b_data});

  assertAllEqual(b_data, b_ref);
}
4226 
TEST(LoopNest, InlineFromLoad) {
  // A[i] = i is produced by one loop and B[j] = A[j] by another. Inlining A
  // must succeed and rewrite the second store to B[j] = j.
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store_a = For::make(i, 0, N, Store::make(a, {i}, i));
  auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j})));
  LoopNest l(Block::make({store_a, store_b}), {b.node()});

  // Fail right here if inlining is rejected, instead of relying only on the
  // pattern match below to notice it.
  ASSERT_TRUE(l.computeInline(a.node()));

  // Check that A[j] is replaced with j after inlining
  std::ostringstream oss;
  oss << *l.root_stmt();
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: for (int j
# CHECK-NOT: B[j] = A[j]
# CHECK-NEXT: B[j] = j
)IR",
      oss.str());
}
4250 
TEST(LoopNest, OptimizeConditionalsSimple) {
  // A single store whose value is selected by a comparison on the loop
  // variable. After optimizeConditionals() the IR is expected to contain one
  // loop per branch, with no conditional left in the checked stores.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
  //   }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {15}, kInt);
  VarHandle i("i", kInt);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto below_five = CompareSelect::make(i, 5, kLT);
  auto from_b = Load::make(b_buf, {i});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto from_c = Load::make(c_buf, {i - 5});
  auto store =
      Store::make(a_buf, {i}, IfThenElse::make(below_five, from_b, from_c));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, store);

  LoopNest nest(Block::make({forI}), {a_buf.node()});
  nest.optimizeConditionals();

  checkIR(nest.root_stmt(), R"IR(
# CHECK: for (int i = 0; i < 5
# CHECK-NEXT: A[i] = B[i]
# CHECK: for (int i = 0; i < 15
# CHECK-NEXT: A[i + 5] = C[i]
      )IR");
}
4290 
TEST(LoopNest, OptimizeConditionalsNestedConditions) {
  // The stored value is chosen by two nested comparisons on the loop
  // variable. After optimizeConditionals() the IR is expected to contain
  // three loops, one for each of the B, C and D ranges.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
  //   }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);

  // Inner selection: B for i < 5, otherwise C shifted by 5.
  auto inner_choice = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(i, 5, kLT),
      Load::make(b_buf, {i}),
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      Load::make(c_buf, {i - 5}));
  // Outer selection: the inner result for i < 10, otherwise D shifted by 10.
  auto outer_choice = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(i, 10, kLT),
      inner_choice,
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      Load::make(d_buf, {i - 10}));
  auto store = Store::make(a_buf, {i}, outer_choice);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, store);

  LoopNest nest(Block::make({forI}), {a_buf.node()});
  nest.optimizeConditionals();

  checkIR(nest.root_stmt(), R"IR(
# CHECK: for (int i = 0; i < 5
# CHECK-NEXT: A[i] = B[i]
# CHECK: for (int i = 0; i < 5
# CHECK-NEXT: A[i + 5] = C[i]
# CHECK: for (int i = 0; i < 10
# CHECK-NEXT: A[i + 10] = D[i]
      )IR");
}
4337 
TEST(LoopNest, OptimizeConditionalsMultipleStores) {
  // Two separate loops, each with one conditional store. Both are expected
  // to be split into a pair of loops by optimizeConditionals().
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
  //   }
  //   for (int j = 0; j < 100; j++) {
  //     B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j])
  //   }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {100}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {100}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  // First loop: the store into A.
  auto a_value = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(i, 5, kLT),
      Load::make(b_buf, {i}),
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      Load::make(c_buf, {i - 5}));
  auto storeA = Store::make(a_buf, {i}, a_value);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, storeA);

  // Second loop: the store into B.
  auto b_value = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(j, 30, kLT),
      Load::make(c_buf, {j}),
      Load::make(d_buf, {j}));
  auto storeB = Store::make(b_buf, {j}, b_value);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forJ = For::make(j, 0, 100, storeB);

  LoopNest nest(Block::make({forI, forJ}), {a_buf.node()});
  nest.optimizeConditionals();

  checkIR(nest.root_stmt(), R"IR(
# CHECK: for (int i = 0; i < 5
# CHECK-NEXT: A[i] = B[i]
# CHECK: for (int i = 0; i < 15
# CHECK-NEXT: A[i + 5] = C[i]
# CHECK: for (int j = 0; j < 30
# CHECK-NEXT: B[j] = C[j]
# CHECK: for (int j = 0; j < 70
# CHECK-NEXT: B[j + 30] = D[j + 30]
      )IR");
}
4395 
TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) {
  // One loop holds two conditional stores. The expected IR splits the loop
  // at the boundary of the first conditional (i < 5); in the second resulting
  // loop the store to B still carries its IfThenElse.
  //
  // Input IR:
  //   for (int i = 0; i < 50; i++) {
  //     A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
  //     B[i] = IfThenElse(i<30 ? 1 : 0, C[i], D[i])
  //   }
  // Only the first conditional, in the write to A, will be optimized.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {100}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {100}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {100}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {100}, kInt);
  VarHandle i("i", kInt);

  auto a_value = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(i, 5, kLT),
      Load::make(b_buf, {i}),
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      Load::make(c_buf, {i - 5}));
  auto storeA = Store::make(a_buf, {i}, a_value);

  auto b_value = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(i, 30, kLT),
      Load::make(c_buf, {i}),
      Load::make(d_buf, {i}));
  auto storeB = Store::make(b_buf, {i}, b_value);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 50, Block::make({storeA, storeB}));

  LoopNest nest(Block::make({forI}), {a_buf.node()});
  nest.optimizeConditionals();

  checkIR(nest.root_stmt(), R"IR(
# CHECK: for (int i = 0; i < 5
# CHECK-NEXT: A[i] = B[i]
# CHECK-NEXT: B[i] = C[i]
# CHECK: for (int i = 0; i < 45
# CHECK-NEXT: A[i + 5] = C[i]
# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5])
      )IR");
}
4447 
TEST(LoopNest, OptimizeConditionalsOuterLoopVar) {
  // The conditions test `i`, but the store sits inside an inner loop over
  // `j`. optimizeConditionals() is expected to leave the IR unchanged, which
  // is verified by hashing the root statement before and after.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
  //     }
  //   }
  // Currently, this case where the condition variable `i` is not the
  // inner-most loop variable, is not optimized.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  // Inner selection: B for i < 5, otherwise C shifted by 5.
  auto inner_choice = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(i, 5, kLT),
      Load::make(b_buf, {i}),
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      Load::make(c_buf, {i - 5}));
  // Outer selection: the inner result for i < 10, otherwise D shifted by 10.
  auto outer_choice = IfThenElse::make(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      CompareSelect::make(i, 10, kLT),
      inner_choice,
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      Load::make(d_buf, {i - 10}));
  auto store = Store::make(a_buf, {i}, outer_choice);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forJ = For::make(j, 0, 100, store);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, forJ);
  LoopNest nest(Block::make({forI}), {a_buf.node()});

  HashProvider hasher;
  auto hash_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  auto hash_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(hash_before, hash_after);
}
4490 
TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10])
  //   }
  // Both conditions use '<', but the comparison values are out of order: the
  // outer condition compares against 5 while the nested one compares against
  // 10. The test asserts that no optimization is done for this input, i.e.
  // the root statement hashes the same before and after the call.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto store = Store::make(
      a_buf,
      {i},
      IfThenElse::make(
          CompareSelect::make(i, 5, kLT),
          IfThenElse::make(
              CompareSelect::make(i, 10, kLT),
              Load::make(b_buf, {i}),
              Load::make(c_buf, {i - 5})),
          Load::make(d_buf, {i - 10})));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, store);
  auto par = Block::make({forI});
  LoopNest nest(par, {a_buf.node()});

  // Equal hashes mean optimizeConditionals() left the IR untouched.
  HashProvider hasher;
  auto hash_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  auto hash_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(hash_before, hash_after);
}
4529 
TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<N, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
  //   }
  // The outer condition compares `i` against the variable `N` instead of a
  // constant. The test asserts that no optimization is done for this input,
  // i.e. the root statement hashes the same before and after the call.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);
  VarHandle N("N", kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto store = Store::make(
      a_buf,
      {i},
      IfThenElse::make(
          CompareSelect::make(i, N, kLT),
          IfThenElse::make(
              CompareSelect::make(i, 5, kLT),
              Load::make(b_buf, {i}),
              Load::make(c_buf, {i - 5})),
          Load::make(d_buf, {i - 10})));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, store);
  auto par = Block::make({forI});
  LoopNest nest(par, {a_buf.node()});

  // Equal hashes mean optimizeConditionals() left the IR untouched.
  HashProvider hasher;
  auto hash_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  auto hash_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(hash_before, hash_after);
}
4569 
TEST(LoopNest, OptimizeConditionalsInvalidCondition) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10])
  //   }
  // The nested condition uses '>' rather than '<', so no optimization should
  // happen; the root statement must hash the same before and after the call.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);

  // Nested select guarded by the unsupported `i > 5` comparison.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto nested_select = IfThenElse::make(
      CompareSelect::make(i, 5, kGT),
      Load::make(b_buf, {i}),
      Load::make(c_buf, {i - 5}));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto top_select = IfThenElse::make(
      CompareSelect::make(i, 10, kLT),
      nested_select,
      Load::make(d_buf, {i - 10}));
  auto store = Store::make(a_buf, {i}, top_select);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto loop = For::make(i, 0, 20, store);
  auto root = Block::make({loop});
  LoopNest nest(root, {a_buf.node()});

  // Equal hashes mean optimizeConditionals() left the IR alone.
  HashProvider hasher;
  const auto digest_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  const auto digest_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(digest_before, digest_after);
}
4608 
TEST(LoopNest, OptimizeConditionalsInvalidCondition2) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(10<i, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
  //   }
  // The outer condition is written as "10 < i" (constant on the left), which
  // is an invalid form for this optimization. Nothing should be rewritten, so
  // the root statement must hash the same before and after the call.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto nested_select = IfThenElse::make(
      CompareSelect::make(i, 5, kLT),
      Load::make(b_buf, {i}),
      Load::make(c_buf, {i - 5}));
  // Outer select guarded by the reversed comparison `10 < i`.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto top_select = IfThenElse::make(
      CompareSelect::make(10, i, kLT),
      nested_select,
      Load::make(d_buf, {i - 10}));
  auto store = Store::make(a_buf, {i}, top_select);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto loop = For::make(i, 0, 20, store);
  auto root = Block::make({loop});
  LoopNest nest(root, {a_buf.node()});

  // Equal hashes mean optimizeConditionals() left the IR alone.
  HashProvider hasher;
  const auto digest_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  const auto digest_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(digest_before, digest_after);
}
4648 
TEST(LoopNest, OptimizeConditionalsInvalidCondition3) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10])
  //   }
  // The two conditions test different variables ("i < 10" and "k < 5"), so
  // no optimization should happen; the root statement must hash the same
  // before and after the call.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);
  VarHandle k("k", kInt);

  // Nested select guarded by `k`, not by the loop variable `i`.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto nested_select = IfThenElse::make(
      CompareSelect::make(k, 5, kLT),
      Load::make(b_buf, {i}),
      Load::make(c_buf, {i - 5}));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto top_select = IfThenElse::make(
      CompareSelect::make(i, 10, kLT),
      nested_select,
      Load::make(d_buf, {i - 10}));
  auto store = Store::make(a_buf, {i}, top_select);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto loop = For::make(i, 0, 20, store);
  auto root = Block::make({loop});
  LoopNest nest(root, {a_buf.node()});

  // Equal hashes mean optimizeConditionals() left the IR alone.
  HashProvider hasher;
  const auto digest_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  const auto digest_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(digest_before, digest_after);
}
4689 
TEST(LoopNest, OptimizeConditionalsInvalidCondition4) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(k<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10])
  //   }
  // Both conditions test 'k', which is not a loop variable, so no
  // optimization should happen; the root statement must hash the same before
  // and after the call.

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle d_buf("D", {10}, kInt);
  VarHandle i("i", kInt);
  VarHandle k("k", kInt);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto nested_select = IfThenElse::make(
      CompareSelect::make(k, 5, kLT),
      Load::make(b_buf, {i}),
      Load::make(c_buf, {i - 5}));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto top_select = IfThenElse::make(
      CompareSelect::make(k, 10, kLT),
      nested_select,
      Load::make(d_buf, {i - 10}));
  auto store = Store::make(a_buf, {i}, top_select);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto loop = For::make(i, 0, 20, store);
  auto root = Block::make({loop});
  LoopNest nest(root, {a_buf.node()});

  // Equal hashes mean optimizeConditionals() left the IR alone.
  HashProvider hasher;
  const auto digest_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  const auto digest_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(digest_before, digest_after);
}
4730 
TEST(LoopNest, OptimizeConditionalsNotNormalized) {
  // Input IR:
  //   for (int i = 2; i < 20; i++) {
  //     A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
  //   }
  // The loop starts at 2 instead of 0. The test asserts that the root
  // statement hashes the same before and after optimizeConditionals().

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle a_buf("A", {20}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle b_buf("B", {5}, kInt);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  BufHandle c_buf("C", {15}, kInt);
  VarHandle i("i", kInt);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto select = IfThenElse::make(
      CompareSelect::make(i, 5, kLT),
      Load::make(b_buf, {i}),
      Load::make(c_buf, {i - 5}));
  auto store = Store::make(a_buf, {i}, select);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto loop = For::make(i, 2, 20, store);
  auto root = Block::make({loop});
  LoopNest nest(root, {a_buf.node()});

  // Equal hashes mean optimizeConditionals() left the IR alone.
  HashProvider hasher;
  const auto digest_before = hasher.hash(nest.root_stmt());
  nest.optimizeConditionals();
  const auto digest_after = hasher.hash(nest.root_stmt());
  ASSERT_EQ(digest_before, digest_after);
}
4763 
colReduce(int M,int N)4764 static std::pair<BufHandle, Tensor> colReduce(int M, int N) {
4765   BufHandle a("a", {M, N}, kFloat);
4766   Tensor t = Reduce(
4767       "b",
4768       {N},
4769       Sum(),
4770       [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); },
4771       {M});
4772   return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))};
4773 }
4774 
splitTailReorder(Tensor b)4775 static StmtPtr splitTailReorder(Tensor b) {
4776   constexpr int kVectorWidth = 8;
4777   LoopNest nest({b});
4778   auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0];
4779   nest.splitWithTail(loops[0], kVectorWidth);
4780   // Now the loopnests will look like:
4781   //
4782   // for (int i_outer = 0; ...
4783   //   for (int i_inner = 0; ...
4784   //     b[i_outer * 8 + i_inner] = float(0);
4785   //     for (int j = 0; ...
4786   //       b[i_outer * 8 + i_inner] = ReduceOp(...);
4787   //
4788   // for (int i_tail = 0; ...
4789   //   b[i_tail + ((100 - 0) / 8) * 8] = float(0);
4790   //   for (int j = 0; ...
4791   //     b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...);
4792   //
4793   // Since there are 4 writes to b, we will get 4 loopnests from the
4794   // call to `getAllLoopNestsWritingToBuf` below.
4795   //
4796   // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)"
4797   // Loopnest #2: {i_outer, i_inner, j};
4798   // We will have to reorder i_inner and j.
4799   auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf());
4800   LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]);
4801   nest.prepareForCodegen();
4802   return nest.root_stmt();
4803 }
4804 
splitMaskReorder(Tensor b)4805 static StmtPtr splitMaskReorder(Tensor b) {
4806   constexpr int kVectorWidth = 8;
4807   LoopNest nest({b});
4808   auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1];
4809   nest.splitWithMask(loops[0], kVectorWidth);
4810   loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1];
4811   LoopNest::reorderAxis(loops[1], loops[2]);
4812   nest.prepareForCodegen();
4813   return nest.root_stmt();
4814 }
4815 
checkColReduce(StmtPtr s,BufHandle p,Tensor t)4816 static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) {
4817   int M = immediateAs<int>(p.dim(0));
4818   int N = immediateAs<int>(p.dim(1));
4819   PaddedBuffer<float> a(M, N);
4820   PaddedBuffer<float> b(N);
4821   PaddedBuffer<float> ref(N);
4822   for (int i = 0; i < M; i++) {
4823     for (int j = 0; j < N; j++) {
4824       a(i, j) = 1.0f;
4825     }
4826   }
4827   for (int i = 0; i < N; i++) {
4828     b(i) = 0.0f;
4829   }
4830   for (int i = 0; i < N; i++) {
4831     ref(i) = 76.0f;
4832   }
4833   SimpleIREvaluator(s, {p, t}).call({a, b});
4834   ExpectAllNear(b, ref, 1e-5);
4835 }
4836 
TEST(LoopNest, ColReduceSplitTailEvenReorder) {
  // N = 128 is a multiple of the split factor used by splitTailReorder, and
  // the pattern below expects no further loop after the reordered nest.
  constexpr int M = 76, N = 128;
  const auto buf_and_tensor = colReduce(M, N);
  StmtPtr stmt = splitTailReorder(buf_and_tensor.second);

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i_outer
# CHECK-NEXT: for (int i_inner
# CHECK-NEXT: b[
# CHECK: for (int j
# CHECK-NEXT: for (int i_inner
# CHECK-NEXT: b[
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *stmt;
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());

  // Structure verified; now verify the computed values.
  checkColReduce(stmt, buf_and_tensor.first, buf_and_tensor.second);
}
4858 
TEST(LoopNest, ColReduceSplitTailUnevenReorder) {
  // N = 100 is not a multiple of the split factor used by splitTailReorder,
  // and the pattern below expects an i_tail loop after the reordered nest.
  constexpr int M = 76, N = 100;
  const auto buf_and_tensor = colReduce(M, N);
  StmtPtr stmt = splitTailReorder(buf_and_tensor.second);

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i_outer
# CHECK-NEXT: for (int i_inner
# CHECK-NEXT: b[
# CHECK: for (int j
# CHECK-NEXT: for (int i_inner
# CHECK-NEXT: b[
# CHECK: for (int i_tail
# CHECK-NEXT: b[
# CHECK-NEXT: for (int j
# CHECK-NEXT: b[
      )IR";
  std::ostringstream printed;
  printed << *stmt;
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());

  // Structure verified; now verify the computed values.
  checkColReduce(stmt, buf_and_tensor.first, buf_and_tensor.second);
}
4883 
TEST(LoopNest, ColReduceSplitMaskEvenReorder) {
  // Column reduction with N = 128, transformed by splitMaskReorder and then
  // checked numerically.
  constexpr int M = 76, N = 128;
  const auto buf_and_tensor = colReduce(M, N);
  StmtPtr stmt = splitMaskReorder(buf_and_tensor.second);
  checkColReduce(stmt, buf_and_tensor.first, buf_and_tensor.second);
}
4890 
TEST(LoopNest, ColReduceSplitMaskUnevenReorder) {
  // Column reduction with N = 100, transformed by splitMaskReorder and then
  // checked numerically.
  constexpr int M = 76, N = 100;
  const auto buf_and_tensor = colReduce(M, N);
  StmtPtr stmt = splitMaskReorder(buf_and_tensor.second);
  checkColReduce(stmt, buf_and_tensor.first, buf_and_tensor.second);
}
4897 
TEST(LoopNest, ReorderAxisWithMultipleConds) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     if i > 5 {
  //       if i < 10 {
  //         for (int j = 0; j < 100; j++) {
  //           A[i] = i * j;
  //         }
  //       }
  //     }
  //   }
  // Reordering i and j is expected to put the j loop outermost while both
  // conditions stay around the store.
  BufHandle a_buf("A", {20}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto product_store = Store::make(a_buf, {i}, Mul::make(i, j));
  auto loop_j = For::make(j, 0, 100, product_store);
  auto below_ten = Cond::make(CompareSelect::make(i, 10, kLT), loop_j, nullptr);
  auto above_five =
      Cond::make(CompareSelect::make(i, 5, kGT), below_ten, nullptr);
  auto loop_i = For::make(i, 0, 20, above_five);
  StmtPtr root = Block::make({loop_i});
  LoopNest l(root, {a_buf.node()});
  LoopNest::reorderAxis(loop_i, loop_j);
  // The nest must still be rooted at the block that was passed in.
  ASSERT_EQ(root, l.root_stmt());
  StmtPtr simplified = IRSimplifier::simplify(root);

  const std::string expected_ir =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: for (int i
# CHECK-NEXT: if (i>5
# CHECK-NEXT: if (i<10
# CHECK-NEXT: A[i] = i * j
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *simplified;
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
4936 
TEST(LoopNest, VectorizeUse) {
  // b = a + 1 and c = b + 2 over 8 elements. Both loops are vectorized, and
  // the printed IR must then contain a store to c indexed by a Ramp.
  constexpr int N = 8;
  BufHandle a("a", {N}, kFloat);
  Tensor b =
      Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; });
  Tensor c =
      Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; });
  LoopNest nest({c}, {b, c});
  auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0];
  ASSERT_TRUE(LoopNest::vectorize(loops[0]));
  loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0];
  ASSERT_TRUE(LoopNest::vectorize(loops[0]));
  nest.prepareForCodegen();
  // The root statement is printed directly; the previously unused local copy
  // of it (and the lint suppression that hid it) is gone.
  std::ostringstream oss;
  oss << *nest.root_stmt();
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: c[Ramp
)IR",
      oss.str());
}
4960 
// FileCheck pattern for a 64-bit loop over 12 elements that computes
// b[i] = a[i] + 1 with int64 immediates.
const char* int64Loop = R"IR(
# CHECK: for (int64_t i = 0ll; i < 12ll; i++) {
# CHECK:   b[i] = (a[i]) + 1ll;
# CHECK: }
)IR";
4966 
TEST(LoopNest, Int64Direct) {
  // Builds the int64 loop by hand, simplifies it and matches the printed IR
  // against the int64Loop pattern.
  constexpr int64_t N = 12;
  BufHandle a("a", {N}, kLong);
  BufHandle b("b", {N}, kLong);
  VarHandle n("i", kLong);
  auto increment = b.store({n}, a.load({n}) + LongImm::make(1l));
  StmtPtr loop = For::make(n, LongImm::make(0l), N, increment);
  StmtPtr simplified = IRSimplifier::simplify(loop);
  std::ostringstream printed;
  printed << *simplified;
  torch::jit::testing::FileCheck().run(int64Loop, printed.str());
}
4979 
TEST(LoopNest, Int64Compute) {
  // Same computation as the hand-built int64 loop, but expressed through
  // Compute and lowered by LoopNest before matching the int64Loop pattern.
  constexpr int64_t N = 12;
  BufHandle a("a", {N}, kLong);
  Tensor b = Compute("b", {N}, [&](const VarHandle& n) {
    return a.load(n) + LongImm::make(1l);
  });
  LoopNest nest({b});
  nest.prepareForCodegen();
  nest.simplify();
  std::ostringstream printed;
  printed << *nest.root_stmt();
  torch::jit::testing::FileCheck().run(int64Loop, printed.str());
}
4993 
TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[i] = A[i] + i * j;
  //     }
  //     B[i] = A[i];
  //     for (int k = 0; k < 50; k++) {
  //       B[i] = B[i] + i * k;
  //     }
  //   }
  // Pivots: the A init, the j loop and the B init. The expected result is
  // four separate i loops, one per statement of the original body.
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto init_a = Store::make(a_buf, {i}, 0);
  auto accumulate_a = Store::make(
      a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)));
  auto loop_j = For::make(j, 0, 100, accumulate_a);
  auto init_b = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
  auto accumulate_b = Store::make(
      b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)));
  auto loop_k = For::make(k, 0, 50, accumulate_b);
  auto loop_i =
      For::make(i, 0, 20, Block::make({init_a, loop_j, init_b, loop_k}));
  auto root = Block::make({loop_i});

  LoopNest nest(root, {a_buf.node(), b_buf.node()});
  auto distributed = LoopNest::distributeLoop(loop_i, {init_a, loop_j, init_b});

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: A[i] = 0
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i] =
# CHECK: for (int i
# CHECK-NEXT: B[i] = A[i]
# CHECK: for (int i
# CHECK-NEXT: for (int k
# CHECK-NEXT: B[i] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());

  // The first loop after distribution must be same as the original For.
  ASSERT_EQ(distributed.front(), loop_i);
}
5053 
TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[i] = A[i] + i * j;
  //     }
  //     B[i] = A[i];
  //     for (int k = 0; k < 50; k++) {
  //       B[i] = B[i] + i * k;
  //     }
  //   }
  // Only the j loop is a pivot. The expected result is two i loops: one up to
  // and including the j loop, and one holding the B statements.
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto init_a = Store::make(a_buf, {i}, 0);
  auto accumulate_a = Store::make(
      a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)));
  auto loop_j = For::make(j, 0, 100, accumulate_a);
  auto init_b = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
  auto accumulate_b = Store::make(
      b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)));
  auto loop_k = For::make(k, 0, 50, accumulate_b);
  auto loop_i =
      For::make(i, 0, 20, Block::make({init_a, loop_j, init_b, loop_k}));
  auto root = Block::make({loop_i});

  LoopNest nest(root, {a_buf.node(), b_buf.node()});
  auto distributed = LoopNest::distributeLoop(loop_i, {loop_j});

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: A[i] = 0
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i] =
# CHECK: for (int i
# CHECK-NEXT: B[i] = A[i]
# CHECK-NEXT: for (int k
# CHECK-NEXT: B[i] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());

  // The first loop after distribution must be same as the original For.
  ASSERT_EQ(distributed.front(), loop_i);
}
5110 
TEST(LoopNest, DistributeLoopWithoutAnyPivot) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[i] = A[i] + i * j;
  //     }
  //     B[i] = A[i];
  //     for (int k = 0; k < 50; k++) {
  //       B[i] = B[i] + i * k;
  //     }
  //   }
  // distributeLoop is called without pivots. The expected result is four
  // separate i loops, one per statement of the original body.
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto init_a = Store::make(a_buf, {i}, 0);
  auto accumulate_a = Store::make(
      a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)));
  auto loop_j = For::make(j, 0, 100, accumulate_a);
  auto init_b = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
  auto accumulate_b = Store::make(
      b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)));
  auto loop_k = For::make(k, 0, 50, accumulate_b);
  auto loop_i =
      For::make(i, 0, 20, Block::make({init_a, loop_j, init_b, loop_k}));
  auto root = Block::make({loop_i});

  LoopNest nest(root, {a_buf.node(), b_buf.node()});
  auto distributed = LoopNest::distributeLoop(loop_i);

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: A[i] = 0
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i] =
# CHECK: for (int i
# CHECK-NEXT: B[i] = A[i]
# CHECK: for (int i
# CHECK-NEXT: for (int k
# CHECK-NEXT: B[i] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());

  // The first loop after distribution must be same as the original For.
  ASSERT_EQ(distributed.front(), loop_i);
}
5170 
TEST(LoopNest, DistributeLoopOverInnerLoops) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[i] = A[i] + i * j;
  //     }
  //     B[i] = A[i];
  //     for (int k = 0; k < 50; k++) {
  //       B[i] = B[i] + i * k;
  //     }
  //   }
  // The expected result is two i loops, each ending with one of the inner
  // loops (j, then k).
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto init_a = Store::make(a_buf, {i}, 0);
  auto accumulate_a = Store::make(
      a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)));
  auto loop_j = For::make(j, 0, 100, accumulate_a);
  auto init_b = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
  auto accumulate_b = Store::make(
      b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)));
  auto loop_k = For::make(k, 0, 50, accumulate_b);
  auto loop_i =
      For::make(i, 0, 20, Block::make({init_a, loop_j, init_b, loop_k}));
  auto root = Block::make({loop_i});

  LoopNest nest(root, {a_buf.node(), b_buf.node()});
  auto distributed = LoopNest::distributeLoopOverInnerLoops(loop_i);

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: A[i] = 0
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i] =
# CHECK: for (int i
# CHECK-NEXT: B[i] = A[i]
# CHECK-NEXT: for (int k
# CHECK-NEXT: B[i] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());

  // The first loop after distribution must be same as the original For.
  ASSERT_EQ(distributed.front(), loop_i);
}
5227 
TEST(LoopNest,DistributeLoopAndParentsWithoutAnyPivot)5228 TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) {
5229   // Input IR:
5230   // for (int m = 0; m < 50; m++) {
5231   //   for (int i = 0; i < 20; i++) {
5232   //     A[m,i] = 0;
5233   //     for (int j = 0; j < 100; j++) {
5234   //       A[m,i] = A[m,i] + i * j;
5235   //     }
5236   //     B[m,i] = A[m,i];
5237   //     for (int k = 0; k < 50; k++) {
5238   //       B[m,i] = B[m,i] + i * k;
5239   //     }
5240   //   }
5241   // }
5242   BufHandle a_buf("A", {100, 100}, kInt);
5243   BufHandle b_buf("B", {100, 100}, kInt);
5244   VarHandle m("m", kInt);
5245   VarHandle i("i", kInt);
5246   VarHandle j("j", kInt);
5247   VarHandle k("k", kInt);
5248   auto initA = Store::make(a_buf, {m, i}, 0);
5249   auto forJ = For::make(
5250       j,
5251       0,
5252       100,
5253       Store::make(
5254           a_buf,
5255           {m, i},
5256           Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j))));
5257   auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i}));
5258   auto forK = For::make(
5259       k,
5260       0,
5261       50,
5262       Store::make(
5263           b_buf,
5264           {m, i},
5265           Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k))));
5266   auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
5267 
5268   {
5269     // Check the case of distributing loop and its parents over all the
5270     // statements in the loop.
5271     const std::string& verification_pattern =
5272         R"IR(
5273 # CHECK: for (int m
5274 # CHECK-NEXT: for (int i
5275 # CHECK-NEXT: A[m, i] = 0
5276 # CHECK: for (int m
5277 # CHECK-NEXT: for (int i
5278 # CHECK-NEXT: for (int j
5279 # CHECK-NEXT: A[m, i] =
5280 # CHECK: for (int m
5281 # CHECK-NEXT: for (int i
5282 # CHECK-NEXT: B[m, i] = A[m, i]
5283 # CHECK: for (int m
5284 # CHECK-NEXT: for (int i
5285 # CHECK-NEXT: for (int k
5286 # CHECK-NEXT: B[m, i] =
5287 # CHECK-NOT: for (
5288         )IR";
5289 
5290     auto newForI = to<For>(Stmt::clone(forI));
5291     auto forM = For::make(m, 0, 50, newForI);
5292     auto par = Block::make({forM});
5293     LoopNest nest(par, {a_buf.node(), b_buf.node()});
5294     auto newLoops = LoopNest::distributeLoopAndParents(newForI);
5295 
5296     std::ostringstream oss;
5297     oss << *par;
5298     torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5299 
5300     // The first loop after distribution must be same as the original For.
5301     ASSERT_EQ(newLoops.front(), forM);
5302   }
5303 
5304   {
5305     // Check the case of distributing loop and its parents over all the inner
5306     // loops.
5307     const std::string& verification_pattern =
5308         R"IR(
5309 # CHECK: for (int m
5310 # CHECK-NEXT: for (int i
5311 # CHECK-NEXT: A[m, i] = 0
5312 # CHECK-NEXT: for (int j
5313 # CHECK-NEXT: A[m, i] =
5314 # CHECK: for (int m
5315 # CHECK-NEXT: for (int i
5316 # CHECK-NEXT: B[m, i] = A[m, i]
5317 # CHECK-NEXT: for (int k
5318 # CHECK-NEXT: B[m, i] =
5319 # CHECK-NOT: for (
5320         )IR";
5321 
5322     auto newForI = to<For>(Stmt::clone(forI));
5323     auto forM = For::make(m, 0, 50, newForI);
5324     auto par = Block::make({forM});
5325     LoopNest nest(par, {a_buf.node(), b_buf.node()});
5326     auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI);
5327 
5328     std::ostringstream oss;
5329     oss << *par;
5330     torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5331 
5332     // The first loop after distribution must be same as the original For.
5333     ASSERT_EQ(newLoops.front(), forM);
5334   }
5335 }
5336 
TEST(LoopNest, fuseLoopsSimple) {
  // Two sibling loops with identical constant bounds, each writing to its own
  // buffer, are expected to fuse into the first loop.
  //
  // IR before fusion:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);

  auto storeA = Store::make(a_buf, {j}, Mul::make(10, j));
  auto storeB = Store::make(b_buf, {k}, Mul::make(20, k));
  auto forJ = For::make(j, 0, 100, storeA);
  auto forK = For::make(k, 0, 100, storeB);
  auto root = Block::make({forJ, forK});

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused));

  // Exactly one loop (over j) must remain, holding both stores.
  const std::string expected = R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: B[j] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forJ);
}
5370 
TEST(LoopNest, fuseLoopsMultiple) {
  // Three sibling loops with identical bounds are fused in a single call.
  //
  // IR before fusion:
  //   for (int i = 0; i < 100; i++) {
  //     A[i+100] = 20 + i;
  //   }
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {200}, kInt);
  BufHandle b_buf("B", {100}, kInt);

  auto storeUpperA = Store::make(a_buf, {i + 100}, Add::make(20, i));
  auto storeLowerA = Store::make(a_buf, {j}, Mul::make(10, j));
  auto storeB = Store::make(b_buf, {k}, Mul::make(20, k));
  auto forI = For::make(i, 0, 100, storeUpperA);
  auto forJ = For::make(j, 0, 100, storeLowerA);
  auto forK = For::make(k, 0, 100, storeB);
  auto root = Block::make({forI, forJ, forK});

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused));

  // All three stores must end up in one loop over i.
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT: A[i + 100] =
# CHECK-NEXT: A[i] =
# CHECK-NEXT: B[i] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forI);
}
5411 
TEST(LoopNest, fuseLoopsNested) {
  // Fuses two outer loops whose bodies each contain an initializing store
  // followed by an inner loop. Only the outer loops are fused; the inner
  // loops (j and k) stay separate.
  //
  // IR before fusion:
  //   for (int m = 0; m < 20; m++) {
  //     A[m] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[m] = A[m] + m * j;
  //     }
  //   }
  //   for (int n = 0; n < 20; n++) {
  //     B[n] = A[n];
  //     for (int k = 0; k < 50; k++) {
  //       B[n] = B[n] + n * k;
  //     }
  //   }
  //
  // NOTE(review): A and B are declared with two dimensions but are indexed
  // with a single index below - confirm this is intentional.
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 100}, kInt);

  // First nest: A[m] accumulates m * j over j.
  auto zeroA = Store::make(a_buf, {m}, 0);
  auto accumA = Store::make(
      a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)));
  auto forJ = For::make(j, 0, 100, accumA);
  auto forM = For::make(m, 0, 20, Block::make({zeroA, forJ}));

  // Second nest: B[n] starts from A[n] and accumulates n * k over k.
  auto seedB = Store::make(b_buf, {n}, Load::make(a_buf, {n}));
  auto accumB = Store::make(
      b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)));
  auto forK = For::make(k, 0, 50, accumB);
  auto forN = For::make(n, 0, 20, Block::make({seedB, forK}));

  auto root = Block::make({forM, forN});

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused));

  const std::string expected = R"IR(
# CHECK: for (int m
# CHECK-NEXT: A[m] = 0
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[m] =
# CHECK: B[m] = A[m]
# CHECK-NEXT: for (int k
# CHECK-NEXT: B[m] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forM);
}
5471 
TEST(LoopNest, fuseLoopsNested2D) {
  // Fuses the outer loops of two 2-D nests. The inner loops have different
  // trip counts (100 vs 50) and remain separate after fusion.
  //
  // IR before fusion:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 50; n++) {
  //       B[m,n] = m + n * 100;
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 100}, kInt);

  auto storeA = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forI = For::make(i, 0, 20, For::make(j, 0, 100, storeA));

  auto storeB = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100)));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 50, storeB));

  auto root = Block::make({forI, forM});

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused));

  // One outer loop over i containing both inner loops.
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK: for (int n
# CHECK-NEXT: B[i, n] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forI);
}
5529 
TEST(LoopNest, fuseLoopsNested2DInner) {
  // Fuses two inner loops that share the same enclosing loop.
  //
  // IR before fusion:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //     for (int n = 0; n < 100; n++) {
  //       B[i,n] = i + n * 100;
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle n("n", kInt);
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 100}, kInt);

  auto storeA = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto storeB = Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)));
  auto forJ = For::make(j, 0, 100, storeA);
  auto forN = For::make(n, 0, 100, storeB);
  auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused));

  // The outer loop must now contain a single inner loop over j.
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK-NEXT: B[i, j] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *forI;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forJ);
}
5569 
TEST(LoopNest, fuseLoopsDifferentStopBounds) {
  // Loops whose stop bounds differ (100 vs 50) must be rejected by fuseLoops.
  //
  // Input IR:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < 50; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  // B is indexed by this loop's own variable k, as in the IR above.
  auto forK = For::make(k, 0, 50, Store::make(b_buf, {k}, Mul::make(20, k)));
  // The block is not read again; it exists only to give both loops a common
  // parent.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
5590 
TEST(LoopNest, fuseLoopsDifferentStartBounds) {
  // Loops whose start bounds differ (0 vs 50) must be rejected by fuseLoops.
  //
  // Input IR:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 50; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  // B is indexed by this loop's own variable k, as in the IR above.
  auto forK = For::make(k, 50, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
  // The block is not read again; it exists only to give both loops a common
  // parent.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
5611 
TEST(LoopNest, fuseLoopsNotContiguous) {
  // Loops separated by another statement in the parent block must be
  // rejected by fuseLoops.
  //
  // Input IR:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   B[0] = 0;
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto initB = Store::make(b_buf, {0}, 0);
  // B is indexed by this loop's own variable k, as in the IR above.
  auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
  // The block is not read again; it places initB between the two loops.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, initB, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
5634 
TEST(LoopNest, fuseLoopsWithDifferentParents) {
  // Loops that do not share the same parent (forJ lives inside forI, forK
  // lives in the top-level block) must be rejected by fuseLoops.
  //
  // Input IR:
  //   for (int i = 0; i < 50; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  //   B[0] = 0;
  //   for (int k = 50; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {50, 100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j)));
  auto forI = For::make(i, 0, 50, forJ);
  auto initB = Store::make(b_buf, {0}, 0);
  // B is indexed by this loop's own variable k, as in the IR above; j is only
  // bound inside forJ.
  auto forK = For::make(k, 50, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
  // The block is not read again; it only establishes the parent structure.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI, initB, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
5661 
TEST(LoopNest, fuseLoopsWithVariableBounds) {
  // Loops whose stop bound is the same symbolic variable N are fusable.
  //
  // Input IR:
  //   for (int j = 0; j < N; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < N; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle N("N", kInt);
  auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j)));
  // B is indexed by this loop's own variable k, so the "B[j]" check below
  // verifies that fusion rewrote k to j.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
  auto forK = For::make(k, 0, N, Store::make(b_buf, {k}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: B[j] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forJ);
}
5697 
TEST(LoopNest, fuseLoopsWithExprBounds) {
  // Loops whose stop bound is the same symbolic expression (M + N) are
  // fusable.
  //
  // Input IR:
  //   for (int j = 0; j < M + N; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < M + N; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j)));
  // B is indexed by this loop's own variable k, so the "B[j]" check below
  // verifies that fusion rewrote k to j.
  auto forK = For::make(k, 0, M + N, Store::make(b_buf, {k}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: B[j] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forJ);
}
5733 
TEST(LoopNest, fuseLoopsWithDifferentExprBounds) {
  // The stop bounds are written differently (N * 2 vs N + N) but denote the
  // same value; the test expects fuseLoops to accept them.
  //
  // Input IR:
  //   for (int j = M; j < N * 2; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = M; k < N + N; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j)));
  // B is indexed by this loop's own variable k, so the "B[j]" check below
  // verifies that fusion rewrote k to j.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
  auto forK = For::make(k, M, N + N, Store::make(b_buf, {k}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: B[j] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forJ);
}
5770 
TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) {
  // Both loops write to the same buffer, but to disjoint index ranges
  // ([10, 100) vs [110, 200)), so fusion is expected to succeed.
  //
  // IR before fusion:
  //   for (int j = 10; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 10; k < 100; k++) {
  //     A[k+100] = 30 * k
  //   }
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {200}, kInt);

  auto storeLow = Store::make(a_buf, {j}, Mul::make(10, j));
  auto storeHigh = Store::make(a_buf, {k + 100}, Mul::make(30, k));
  auto forJ = For::make(j, 10, 100, storeLow);
  auto forK = For::make(k, 10, 100, storeHigh);
  auto root = Block::make({forJ, forK});

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused));

  const std::string expected = R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: A[j + 100] =
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forJ);
}
5805 
TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) {
  // Both nests write to the same 2-D buffer but to disjoint regions
  // (rows [0, 20) x cols [0, 100) vs rows [20, 40) x cols [100, 150)), so
  // fusing the outer loops is expected to succeed.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 50; n++) {
  //       A[m+20,n+100] = m + n * 100;
  //     }
  //   }
  // A must span both written regions: rows up to 19 + 20 and columns up to
  // 49 + 100.
  BufHandle a_buf("A", {40, 150}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 =
      Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK: for (int n
# CHECK-NEXT: A[i + 20, n + 100] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forI);
}
5853 
TEST(LoopNest, fuseLoopsWithReductions) {
  // Fuses a loop that reduces into A[i] with a following loop that copies
  // A[m] into C[m].
  //
  // IR before fusion:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = 0
  //     for (int j = 0; j < 100; j++) {
  //       A[i] = A[i] + B[i,j];
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     C[m] = A[m];
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20, 100}, kInt);
  BufHandle c_buf("C", {20}, kInt);

  // Reduction nest.
  auto zeroA = Store::make(a_buf, {i}, 0);
  auto accumA = Store::make(
      a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j})));
  auto forJ = For::make(j, 0, 100, accumA);
  auto forI = For::make(i, 0, 20, Block::make({zeroA, forJ}));

  // Consumer loop.
  auto copyToC = Store::make(c_buf, {m}, Load::make(a_buf, {m}));
  auto forM = For::make(m, 0, 20, copyToC);

  auto root = Block::make({forI, forM});

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused));

  // The copy into C must follow the inner reduction loop, with no further
  // loops after it.
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT: A[i] =
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i] = (A[i]) +
# CHECK-NOT: for (
# CHECK: C[i] = A[i]
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forI);
}
5899 
TEST(LoopNest, fuseLoopsWith2DReductions) {
  // Fuses the outer loop of a 2-D reduction nest with the outer loop of a
  // nest that copies the reduced values into C.
  //
  // IR before fusion:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 50; j++) {
  //       A[i,j] = 0
  //       for (int k = 0; k < 100; k++) {
  //         A[i,j] = A[i,j] + B[i,j,k];
  //       }
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 40; n++) {
  //       C[m,n] = A[m,n];
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  BufHandle a_buf("A", {20, 50}, kInt);
  BufHandle b_buf("B", {20, 50, 100}, kInt);
  BufHandle c_buf("C", {20, 40}, kInt);

  // Reduction nest.
  auto zeroA = Store::make(a_buf, {i, j}, 0);
  auto accumA = Store::make(
      a_buf,
      {i, j},
      Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k})));
  auto forK = For::make(k, 0, 100, accumA);
  auto forJ = For::make(j, 0, 50, Block::make({zeroA, forK}));
  auto forI = For::make(i, 0, 20, forJ);

  // Consumer nest.
  auto copyToC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n}));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 40, copyToC));

  auto root = Block::make({forI, forM});

  ForPtr fused = nullptr;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused));

  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK-NEXT: for (int k
# CHECK-NEXT: A[i, j] = (A[i, j]) +
# CHECK: for (int n
# CHECK-NEXT: C[i, n] = A[i, n]
# CHECK-NOT: for (
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(expected, printed.str());

  // fuseLoops reports the first input loop as the fused result.
  ASSERT_EQ(fused, forI);
}
5957 
TEST(LoopNest, fuseLoopsWithComplexIndices) {
  // The producer and consumer nests access A through the same non-trivial
  // index expression; fusing the outer loops is expected to succeed.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 20; j++) {
  //       A[i,j*20+j+2] = i + j;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 20; n++) {
  //       B[m,n] = A[m,n*20+n+2];
  //     }
  //   }
  // The largest column index into A is 19 * 20 + 19 + 2 = 401, so A needs
  // 402 columns.
  BufHandle a_buf("A", {20, 402}, kInt);
  BufHandle b_buf("B", {20, 400}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j);
  auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA));
  auto storeB =
      Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2}));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB));
  auto par = Block::make({forI, forM});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j
# CHECK: for (int n
# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2]
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forI);
}
6003 
TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) {
  // The second index of A mixes the outer and inner loop variables; the test
  // expects fuseLoops to refuse fusing the outer loops.
  //
  // IR under test:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 20; j++) {
  //       A[i,i*20+j] = i + j;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 20; n++) {
  //       B[m,n] = A[m,m*20+n];  // Both indices of A use m
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  BufHandle a_buf("A", {20, 500}, kInt);
  BufHandle b_buf("B", {20, 500}, kInt);

  auto producer = Store::make(a_buf, {i, i * 20 + j}, i + j);
  auto consumer =
      Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n}));
  auto forI = For::make(i, 0, 20, For::make(j, 0, 20, producer));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 20, consumer));
  // Not read again; gives both outer loops a common parent block.
  auto root = Block::make({forI, forM});

  ForPtr fused = nullptr;
  ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused));
}
6032 
TEST(LoopNest, fuseLoopsWithTranspose) {
  // The consumer reads A with swapped indices, so iteration (m, n) needs a
  // value produced at iteration (n, m); the test expects fuseLoops to refuse.
  //
  // IR under test:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 20; j++) {
  //       A[i,j] = i + j;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 20; n++) {
  //       B[m,n] = A[n,m];  // Transpose
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  BufHandle a_buf("A", {20, 20}, kInt);
  BufHandle b_buf("B", {20, 20}, kInt);

  auto producer = Store::make(a_buf, {i, j}, i + j);
  auto consumer = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m}));
  auto forI = For::make(i, 0, 20, For::make(j, 0, 20, producer));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 20, consumer));
  // Not read again; gives both outer loops a common parent block.
  auto root = Block::make({forI, forM});

  ForPtr fused = nullptr;
  ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused));
}
6061 
TEST(LoopNest, fuseLoopsThatViolateDependencies1) {
  // Both loops write to A; the second writes one element behind the first
  // (A[k-1] vs A[j]). The test expects fuseLoops to refuse.
  //
  // IR under test:
  //   for (int j = 10; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 10; k < 100; k++) {
  //     A[k-1] = 20 * k;
  //   }
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {100}, kInt);

  auto firstWrite = Store::make(a_buf, {j}, Mul::make(10, j));
  auto shiftedWrite = Store::make(a_buf, {k - 1}, Mul::make(20, k));
  auto forJ = For::make(j, 10, 100, firstWrite);
  auto forK = For::make(k, 10, 100, shiftedWrite);
  // Not read again; gives both loops a common parent block.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto root = Block::make({forJ, forK});

  ForPtr fused = nullptr;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused));
}
6082 
TEST(LoopNest, fuseLoopsThatViolateDependencies2) {
  // Both loops write to A and their ranges overlap ([10, 100) vs [60, 150)).
  // The test expects fuseLoops to refuse.
  //
  // IR under test:
  //   for (int j = 10; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 10; k < 100; k++) {
  //     A[k+50] = 20 * k;
  //   }
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {150}, kInt);

  auto firstWrite = Store::make(a_buf, {j}, Mul::make(10, j));
  auto shiftedWrite = Store::make(a_buf, {k + 50}, Mul::make(20, k));
  auto forJ = For::make(j, 10, 100, firstWrite);
  auto forK = For::make(k, 10, 100, shiftedWrite);
  // Not read again; gives both loops a common parent block.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto root = Block::make({forJ, forK});

  ForPtr fused = nullptr;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused));
}
6103 
TEST(LoopNest, fuseLoopsThatViolateDependencies3) {
  // The second nest reads A[n+1], which the first nest only produces in a
  // later iteration. The test expects fuseLoops to refuse.
  //
  // IR under test:
  //   for (int m = 0; m < 20; m++) {
  //     A[m] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[m] = A[m] + m * j;
  //     }
  //   }
  //   for (int n = 0; n < 20; n++) {
  //     B[n] = A[n+1];
  //     for (int k = 0; k < 50; k++) {
  //       B[n] = B[n] + n * k;
  //     }
  //   }
  //
  // NOTE(review): A and B are declared with two dimensions but are indexed
  // with a single index below - confirm this is intentional.
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {25, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);

  // First nest: A[m] accumulates m * j over j.
  auto zeroA = Store::make(a_buf, {m}, 0);
  auto accumA = Store::make(
      a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)));
  auto forJ = For::make(j, 0, 100, accumA);
  auto forM = For::make(m, 0, 20, Block::make({zeroA, forJ}));

  // Second nest: B[n] starts from A[n+1] and accumulates n * k over k.
  auto seedB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1}));
  auto accumB = Store::make(
      b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)));
  auto forK = For::make(k, 0, 50, accumB);
  auto forN = For::make(n, 0, 20, Block::make({seedB, forK}));

  // Not read again; gives both outer loops a common parent block.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto root = Block::make({forM, forN});

  ForPtr fused = nullptr;
  ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused));
}
6146 
TEST(LoopNest, fuseLoopsThatViolateDependencies4) {
  // Both nests write to A; the second writes row m+1, which overlaps rows
  // written by the first nest. The test expects fuseLoops to refuse.
  //
  // IR under test:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 50; n++) {
  //       A[m+1,n] = m + n * 100;
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  BufHandle a_buf("A", {30, 100}, kInt);

  auto firstWrite =
      Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forI = For::make(i, 0, 20, For::make(j, 0, 100, firstWrite));

  auto shiftedWrite =
      Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100)));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 50, shiftedWrite));

  // Not read again; gives both outer loops a common parent block.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto root = Block::make({forI, forM});

  ForPtr fused = nullptr;
  ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused));
}
6188 
TEST(LoopNest, fuseLoopsThatViolateDependencies5) {
  // Two inner loops under the same parent both write to row i of A, the
  // second shifted by one column. The test expects fuseLoops to refuse.
  //
  // IR under test:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //     for (int n = 0; n < 100; n++) {
  //       A[i,n+1] = i + n * 100;
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle n("n", kInt);
  BufHandle a_buf("A", {20, 200}, kInt);

  auto firstWrite =
      Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto shiftedWrite =
      Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)));
  auto forJ = For::make(j, 0, 100, firstWrite);
  auto forN = For::make(n, 0, 100, shiftedWrite);
  // Not read again; makes forJ and forN siblings inside one outer loop.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));

  ForPtr fused = nullptr;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused));
}
6216 
TEST(LoopNest, fuseLoopsThatViolateDependencies6) {
  // A producer loop followed by a consumer that reads A back to front:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * A[99-k];
  //   }
  // In a fused loop, early iterations would read A[99-k] before the producer
  // has stored to it, so fuseLoops is expected to report failure.
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);

  auto produce_a = Store::make(a_buf, {j}, Mul::make(10, j));
  auto forJ = For::make(j, 0, 100, produce_a);

  auto reversed_a = Load::make(a_buf, {ExprHandle(99) - k});
  auto consume_a = Store::make(b_buf, {k}, Mul::make(20, reversed_a));
  auto forK = For::make(k, 0, 100, consume_a);

  // The block itself is never read again; it only parents both loops.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto root = Block::make({forJ, forK});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused));
}
6242 
TEST(LoopNest, fuseLoopsThatViolateDependencies7) {
  // A back-to-front reader of A followed by a loop that overwrites A:
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * A[99-k];
  //   }
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  // In a fused loop, later iterations would read A[99-k] after it had already
  // been overwritten, so fuseLoops is expected to report failure.
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);

  auto reversed_a = Load::make(a_buf, {ExprHandle(99) - k});
  auto read_a = Store::make(b_buf, {k}, Mul::make(20, reversed_a));
  auto forK = For::make(k, 0, 100, read_a);

  auto overwrite_a = Store::make(a_buf, {j}, Mul::make(10, j));
  auto forJ = For::make(j, 0, 100, overwrite_a);

  // The block itself is never read again; it only parents both loops.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto root = Block::make({forK, forJ});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr fused;
  ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused));
}
6268 
TEST(LoopNest, areLoopsPerfectlyNested) {
  // Exercises LoopNest::areLoopsPerfectlyNested on a three-deep nest: first
  // as constructed, then with the loops listed in the wrong order, and finally
  // with an extra store moved between the bodies of the three loops.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI});
  // Listed outermost-to-innermost, the untouched nest is perfectly nested.
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));

  // Specifying the loops in any other order fails.
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK}));
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ}));
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI}));

  // Adding a statement to forK body should be OK: the innermost loop may
  // hold more than one statement.
  auto init = Store::make(a_buf, {i, j}, 0);
  forK->body()->insert_stmt_before(init, store);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));

  // Adding a statement in forJ body should fail this test. The same `init`
  // statement is detached from forK first and then re-attached next to forK.
  forK->body()->remove_stmt(init);
  forJ->body()->insert_stmt_before(init, forK);
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));

  // Similarly, adding a statement in forI body should fail this test.
  forJ->body()->remove_stmt(init);
  forI->body()->insert_stmt_before(init, forJ);
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
}
6310 
TEST(LoopNest, reorderNestedLoops2D) {
  // Swaps the two loops of a 2-D nest and checks the resulting structure.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  // A is declared with rank 2 so that it matches the two-index store above.
  BufHandle a_buf("A", {20, 30}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store = Store::make(a_buf, {i, j}, Mul::make(i, j));
  auto forJ = For::make(j, 0, 30, store);
  auto forI = For::make(i, 0, 20, forJ);
  auto par = Block::make({forI});

  // Permutation {1, 0}: the loop at input position 1 (forJ) is expected at
  // output position 0, and forI at output position 1.
  auto reordered = LoopNest::reorder({forI, forJ}, {1, 0});

  ASSERT_EQ(reordered[0], forJ);
  ASSERT_EQ(reordered[1], forI);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI}));
  // forJ is now the outermost loop and forI directly encloses the store.
  ASSERT_EQ(forJ->get_parent(), par);
  ASSERT_EQ(store->get_parent(), forI->body());
}
6334 
TEST(LoopNest, reorderNestedLoops3D) {
  // Reorders a 3-D nest from (i, j, k) to (k, i, j).
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {20, 30, 40}, kInt);

  auto store = Store::make(a_buf, {i, j, k}, i * j * k);
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  auto root = Block::make({forI});

  auto permuted = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1});

  // The returned loops are listed in the new outermost-to-innermost order.
  ASSERT_EQ(permuted[0], forK);
  ASSERT_EQ(permuted[1], forI);
  ASSERT_EQ(permuted[2], forJ);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ}));
  // forK hangs off the root block; forJ directly encloses the store.
  ASSERT_EQ(forK->get_parent(), root);
  ASSERT_EQ(store->get_parent(), forJ->body());
}
6363 
TEST(LoopNest, reorderNestedLoops4D) {
  // Reorders a 4-D nest from (i, j, k, l) to (k, i, l, j).
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         for (int l = 0; l < 50; l++) {
  //           A[i,j,k,l] = i * j * k * l * 500;
  //         }
  //       }
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle l("l", kInt);
  BufHandle a_buf("A", {20, 30, 40, 50}, kInt);

  auto store = Store::make(a_buf, {i, j, k, l}, i * j * k * l * 500);
  auto forL = For::make(l, 0, 50, store);
  auto forK = For::make(k, 0, 40, forL);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  auto root = Block::make({forI});

  auto permuted = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1});

  // The returned loops are listed in the new outermost-to-innermost order.
  ASSERT_EQ(permuted[0], forK);
  ASSERT_EQ(permuted[1], forI);
  ASSERT_EQ(permuted[2], forL);
  ASSERT_EQ(permuted[3], forJ);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ}));
  // forK hangs off the root block; forJ directly encloses the store.
  ASSERT_EQ(forK->get_parent(), root);
  ASSERT_EQ(store->get_parent(), forJ->body());
}
6400 
TEST(LoopNest, reorderTrivialPermutation) {
  // Applies the identity permutation and checks that the nest is unchanged.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {20, 30, 40}, kInt);

  auto store = Store::make(a_buf, {i, j, k}, i * j * k);
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  auto root = Block::make({forI});

  auto permuted = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2});

  // Every loop keeps its original position.
  ASSERT_EQ(permuted[0], forI);
  ASSERT_EQ(permuted[1], forJ);
  ASSERT_EQ(permuted[2], forK);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
  // forI still hangs off the root block; forK still encloses the store.
  ASSERT_EQ(forI->get_parent(), root);
  ASSERT_EQ(store->get_parent(), forK->body());
}
6429 
TEST(LoopNest, reorderInvalidPermutations) {
  // Checks that LoopNest::reorder rejects malformed permutation vectors.
  //
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  BufHandle a_buf("A", {20, 30, 40}, kInt);

  auto forK = For::make(k, 0, 40, Store::make(a_buf, {i, j, k}, i * j * k));
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto root = Block::make({forI});

  // Wrong number of entries: one too many, then one too few.
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}),
      "invalid permutation size");
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 2}),
      "invalid permutation size");

  // Right number of entries, but not a permutation of {0, 1, 2}: an
  // out-of-range index, then repeated indices.
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}),
      "invalid permutation for reorder");
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}),
      "invalid permutation for reorder");
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}),
      "invalid permutation for reorder");
}
6466 
TEST(LoopNest, reorderInvalidLoopNest) {
  // Checks that LoopNest::reorder rejects loops that are not perfectly
  // nested, either because they are listed in the wrong order or because an
  // extra statement sits between two of the loops.
  //
  // Initial IR (the extra store `A[i] = 0` is only inserted further below,
  // first into the j-loop body and then into the i-loop body):
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI});

  // Specifying the loops in incorrect order fails.
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}),
      "reorder is only allowed on perfectly nested loops");

  // Adding a statement to forJ loop fails.
  auto init = Store::make(a_buf, {i}, 0);
  forJ->body()->insert_stmt_before(init, forK);
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}),
      "reorder is only allowed on perfectly nested loops");

  // Moving that statement to forI loop also fails. The statement is detached
  // from forJ before it is re-attached in front of forJ inside forI.
  forJ->body()->remove_stmt(init);
  forI->body()->insert_stmt_before(init, forJ);
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}),
      "reorder is only allowed on perfectly nested loops");
}
6507 
TEST(LoopNest, compressBufferSimple) {
  // A is written and read within the same iteration of the i-loop:
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 200; ++j) {
  //     A[i,j] = sin(i*j)
  //   }
  //   for (int j = 0; j < 199; ++j) {
  //     B[i,j] = A[i,j] + A[i, j+1]
  //   }
  // }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);

  auto writeA = Store::make(aBuf, {i, j}, sin(i * j));
  auto forJ1 = For::make(j, 0, 200, writeA);
  auto writeB = Store::make(
      bBuf, {i, j}, Load::make(aBuf, {i, j}) + Load::make(aBuf, {i, j + 1}));
  auto forJ2 = For::make(j, 0, 199, writeB);
  auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
  auto root = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), root);

  // Every access to A is expected to use 0 as its first index afterwards,
  // while accesses to B are left alone.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, j] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1])
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(verification_pattern, printed.str());

  // The buffer keeps its rank; only the first extent shrinks to 1.
  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}
6551 
TEST(LoopNest, compressBufferMultipleDims) {
  // A is written and read within the same (i, j) iteration:
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 200; ++j) {
  //     A[i,j] = sin(i*j)
  //     B[i,j] = A[i,j] + A[i,j]
  //   }
  // }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);

  auto writeA = Store::make(aBuf, {i, j}, sin(i * j));
  auto writeB = Store::make(
      bBuf, {i, j}, Load::make(aBuf, {i, j}) + Load::make(aBuf, {i, j}));
  auto forJ = For::make(j, 0, 200, Block::make({writeA, writeB}));
  auto forI = For::make(i, 0, 100, forJ);
  auto root = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), root);

  // Both indices of every access to A are expected to become 0.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, 0] =
# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0])
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(verification_pattern, printed.str());

  // The buffer keeps its rank; both extents shrink to 1.
  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
}
6589 
TEST(LoopNest, compressBufferMultipleDims2) {
  // Input IR:
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 200; ++j) {
  //     for (int k = 0; k < 300; ++k) {
  //       A[i,j,k] = sin(i*j*k)
  //     }
  //     for (int k = 0; k < 299; ++k) {
  //       B[i,j,k] = A[i,j,k] + A[i,j,k+1]
  //     }
  //   }
  // }
  // The two k-loops share the enclosing i- and j-loops, so the first two
  // dimensions of A are expected to be compressed while the k dimension,
  // which is accessed across two separate loops, stays at full size.
  BufHandle aBuf("A", {100, 200, 300}, kInt);
  BufHandle bBuf("B", {100, 200, 300}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k));
  auto forK1 = For::make(k, 0, 300, store1);
  auto store2 = Store::make(
      bBuf,
      {i, j, k},
      Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1})));
  auto forK2 = For::make(k, 0, 299, store2);
  auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2}));
  auto forI = For::make(i, 0, 100, forJ);
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  std::ostringstream oss;
  oss << *par;
  // Held by value rather than as a reference bound to a temporary.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: for (int k
# CHECK-NEXT: A[0, 0, k] =
# CHECK: for (int k
# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The buffer keeps its rank; only the first two extents shrink to 1.
  ASSERT_EQ(aBuf.node()->ndim(), 3);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300);
}
6637 
TEST(LoopNest, compressBufferDifferentOrderIndices) {
  // Input IR:
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 200; ++j) {
  //     A[j, i] = sin(i*j)
  //   }
  //   for (int j = 0; j < 99; ++j) {
  //     B[i, j] = A[j, i] + A[j+1, i]
  //   }
  // }
  // Here the outer loop variable i indexes the *second* dimension of A, so
  // that is the dimension expected to be compressed.
  //
  // NOTE(review): A is declared as {100, 200} while its first index j runs up
  // to 199 in the first j-loop; the declared extents look transposed relative
  // to the accesses. The assertions below rely on the declared extents, so
  // the declaration is left as is -- confirm whether this is intentional.
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j)));
  auto forJ2 = For::make(
      j,
      0,
      99,
      Store::make(
          bBuf,
          {i, j},
          Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i}))));
  auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  std::ostringstream oss;
  oss << *par;
  // Held by value rather than as a reference bound to a temporary.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[j, 0] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The buffer keeps its rank; only the second extent shrinks to 1.
  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
}
6681 
TEST(LoopNest, compressBufferVariableBounds) {
  // Same shape as the simple case, but with symbolic loop bounds M and N:
  // for (int i = 0; i < M; ++i) {
  //   for (int j = 0; j < N; ++j) {
  //     A[i,j] = sin(i*j)
  //   }
  //   for (int j = 0; j < N-1; ++j) {
  //     B[i,j] = A[i,j] + A[i, j+1]
  //   }
  // }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);

  auto writeA = Store::make(aBuf, {i, j}, sin(i * j));
  auto forJ1 = For::make(j, 0, N, writeA);
  auto writeB = Store::make(
      bBuf, {i, j}, Load::make(aBuf, {i, j}) + Load::make(aBuf, {i, j + 1}));
  auto forJ2 = For::make(j, 0, N - 1, writeB);
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2}));
  auto root = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), root);

  // Every access to A is expected to use 0 as its first index afterwards.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, j] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1])
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(verification_pattern, printed.str());

  // The buffer keeps its rank; only the first extent shrinks to 1.
  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}
6728 
TEST(LoopNest, compressBufferNoCommonParentLoops) {
  // The writer and the reader of A live in two separate top-level nests:
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 200; ++j) {
  //     A[i,j] = sin(i*j)
  //   }
  // }
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 199; ++j) {
  //     B[i,j] = A[i,j] + A[i, j+1]
  //   }
  // }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);

  auto writeA = Store::make(aBuf, {i, j}, sin(i * j));
  auto forJ1 = For::make(j, 0, 200, writeA);
  auto writeB = Store::make(
      bBuf, {i, j}, Load::make(aBuf, {i, j}) + Load::make(aBuf, {i, j + 1}));
  auto forJ2 = For::make(j, 0, 199, writeB);
  auto forI1 = For::make(i, 0, 100, forJ1);
  auto forI2 = For::make(i, 0, 100, forJ2);
  auto root = Block::make({forI1, forI2});
  LoopNest::compressBuffer(aBuf.node(), root);

  // Neither the code nor the buffer is expected to change.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1])
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(verification_pattern, printed.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}
6777 
TEST(LoopNest, compressBufferIndicesMixed) {
  // The first index of A mixes the outer and the inner loop variable:
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 200; ++j) {
  //     A[i + j, j] = sin(i*j)
  //   }
  //   for (int j = 0; j < 199; ++j) {
  //     B[i,j] = A[i + j, j] + A[i + j, j+1]
  //   }
  // }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle aBuf("A", {300, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);

  auto writeA = Store::make(aBuf, {i + j, j}, sin(i * j));
  auto forJ1 = For::make(j, 0, 200, writeA);
  auto writeB = Store::make(
      bBuf,
      {i, j},
      Load::make(aBuf, {i + j, j}) + Load::make(aBuf, {i + j, j + 1}));
  auto forJ2 = For::make(j, 0, 199, writeB);
  auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
  auto root = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), root);

  // Neither the code nor the buffer is expected to change.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i + j, j] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1])
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(verification_pattern, printed.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}
6823 
TEST(LoopNest, compressMultipleBuffers) {
  // Three buffers are chained inside one i-loop:
  // for (int i = 0; i < 100; ++i) {
  //   for (int j = 0; j < 200; ++j) {
  //     A[i,j] = sin(i*j)
  //   }
  //   for (int k = 0; k < 199; ++k) {
  //     B[i,k] = A[i,k] + A[i, k+1]
  //   }
  //   for (int m = 0; m < 50; ++m) {
  //     C[i,m] = B[i,m]
  //   }
  // }
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle m("m", kInt);
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  BufHandle cBuf("C", {100, 200}, kInt);

  auto writeA = Store::make(aBuf, {i, j}, sin(i * j));
  auto forJ = For::make(j, 0, 200, writeA);
  auto writeB = Store::make(
      bBuf, {i, k}, Load::make(aBuf, {i, k}) + Load::make(aBuf, {i, k + 1}));
  auto forK = For::make(k, 0, 199, writeB);
  auto writeC = Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}));
  auto forM = For::make(m, 0, 50, writeC);
  auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM}));
  auto root = Block::make({forI});

  // All three buffers are expected to be compressed:
  //   A[100, 200] -> A[1, 200]
  //   B[100, 200] -> B[1, 200]
  //   C[100, 200] -> C[1, 1]
  LoopNest::compressAllBuffers(root);

  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, j] =
# CHECK: for (int k
# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1])
# CHECK: for (int m
# CHECK-NEXT: C[0, 0] = B[0, m]
      )IR";
  std::ostringstream printed;
  printed << *root;
  torch::jit::testing::FileCheck().run(verification_pattern, printed.str());

  // Every buffer keeps its rank; only the extents change.
  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
  ASSERT_EQ(bBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200);
  ASSERT_EQ(cBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1);
}
6888 
TEST(LoopNest, sanitizeNames) {
  // Builds a Compute and a Reduce whose buffer and dimension names contain
  // characters such as '$', ':', '!', '%', '"' and '+', runs
  // LoopNest::sanitizeNames on the lowered statement, and checks the printed
  // names against the expected pattern below.
  std::vector<ExprHandle> dim_args;
  // Let's pick names that would overlap with default index names if not
  // sanitized properly:
  dim_args.emplace_back(ExprHandle(alloc<Var>("i", kInt)));
  dim_args.emplace_back(ExprHandle(alloc<Var>("N:2", kInt)));
  // Now let's create many more dimensions (ten, all named "N", for twelve in
  // total) so that the same letter has to be used for different loops.
  for (int i = 0; i < 10; i++) {
    dim_args.emplace_back(ExprHandle(alloc<Var>("N", kInt)));
  }

  // Now create two Computes whose names conflict after sanitization:
  Tensor X = Compute("$X:!", dim_args, [&](const std::vector<VarHandle>& v) {
    return v[0] + v[1] + v[9] + 1;
  });
  // Y has no output dimensions and sums X over all of dim_args.
  Tensor Y = Reduce(
      "%X\"+",
      {},
      Sum(),
      [&](const std::vector<VarHandle>& v) { return X.load(v); },
      dim_args);

  // Finally, let's verify what we got after sanitization:
  LoopNest l({X, Y});
  StmtPtr s = l.root_stmt();
  LoopNest::sanitizeNames(s);

  std::ostringstream oss;
  oss << *s;
  // Expected output: loop indices are named i..p and then i1..l1, the loop
  // bounds appear as i_1, N_2_1 and N_9 down to N, and the two buffers are
  // printed as v_X__ and v_X___1. The second nest reuses the same index
  // letters with an extra numeric suffix.
  const std::string& verification_pattern =
      R"IR(
# CHECK:  for (int i = 0; i < i_1; i++) {
# CHECK-NEXT:    for (int j = 0; j < N_2_1; j++) {
# CHECK-NEXT:      for (int k = 0; k < N_9; k++) {
# CHECK-NEXT:        for (int l = 0; l < N_8; l++) {
# CHECK-NEXT:          for (int m = 0; m < N_7; m++) {
# CHECK-NEXT:            for (int n = 0; n < N_6; n++) {
# CHECK-NEXT:              for (int o = 0; o < N_5; o++) {
# CHECK-NEXT:                for (int p = 0; p < N_4; p++) {
# CHECK-NEXT:                  for (int i1 = 0; i1 < N_3; i1++) {
# CHECK-NEXT:                    for (int j1 = 0; j1 < N_2; j1++) {
# CHECK-NEXT:                      for (int k1 = 0; k1 < N_1; k1++) {
# CHECK-NEXT:                        for (int l1 = 0; l1 < N; l1++) {
# CHECK-NEXT:                          v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1;
# CHECK:  v_X___1 = int(0);
# CHECK-NEXT:  for (int i_2 = 0; i_2 < i_1; i_2++) {
# CHECK-NEXT:    for (int j_1 = 0; j_1 < N_2_1; j_1++) {
# CHECK-NEXT:      for (int k_1 = 0; k_1 < N_9; k_1++) {
# CHECK-NEXT:        for (int l_1 = 0; l_1 < N_8; l_1++) {
# CHECK-NEXT:          for (int m_1 = 0; m_1 < N_7; m_1++) {
# CHECK-NEXT:            for (int n_1 = 0; n_1 < N_6; n_1++) {
# CHECK-NEXT:              for (int o_1 = 0; o_1 < N_5; o_1++) {
# CHECK-NEXT:                for (int p_1 = 0; p_1 < N_4; p_1++) {
# CHECK-NEXT:                  for (int i1_1 = 0; i1_1 < N_3; i1_1++) {
# CHECK-NEXT:                    for (int j1_1 = 0; j1_1 < N_2; j1_1++) {
# CHECK-NEXT:                      for (int k1_1 = 0; k1_1 < N_1; k1_1++) {
# CHECK-NEXT:                        for (int l1_1 = 0; l1_1 < N; l1_1++) {
# CHECK-NEXT:                          v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1});
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
6951 
6952 } // namespace jit
6953 } // namespace torch
6954