• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 
3 #include <stdexcept>
4 #include "test/cpp/tensorexpr/test_base.h"
5 
6 #include <torch/csrc/jit/tensorexpr/expr.h>
7 #include <torch/csrc/jit/tensorexpr/ir.h>
8 #include <torch/csrc/jit/tensorexpr/ir_verifier.h>
9 #include <torch/csrc/jit/tensorexpr/loopnest.h>
10 #include <torch/csrc/jit/tensorexpr/tensor.h>
11 #include <torch/csrc/jit/testing/file_check.h>
12 
13 #include <sstream>
14 namespace torch {
15 namespace jit {
16 
17 using namespace torch::jit::tensorexpr;
18 
// Each bitwise node below pairs a kInt operand with a kFloat operand, and
// verify() is expected to reject every one of them.
TEST(IRVerifier, BitwiseOps) {
  VarPtr int_operand = alloc<Var>("x", kInt);
  VarPtr float_operand = alloc<Var>("y", kFloat);

  // The node is fully built as the argument expression, before this lambda
  // runs, so only verify() itself executes inside EXPECT_ANY_THROW.
  auto expect_rejected = [](const ExprPtr& a) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  };

  expect_rejected(alloc<And>(int_operand, float_operand));
  expect_rejected(alloc<Or>(int_operand, float_operand));
  expect_rejected(alloc<Xor>(int_operand, float_operand));
  expect_rejected(alloc<Lshift>(int_operand, float_operand));
  expect_rejected(alloc<Rshift>(int_operand, float_operand));
}
48 
// CompareSelect nodes that mix an int immediate with a float immediate are
// expected to be rejected by verify().
TEST(IRVerifier, CompareSelect) {
  ExprPtr int_imm = alloc<IntImm>(1);
  ExprPtr float_imm = alloc<FloatImm>(3.14f);

  // Build first (argument evaluation), then verify inside the expectation.
  auto expect_rejected = [](const ExprPtr& a) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  };

  // First two operands are both int; the last two differ (int vs float).
  // NOTE(review): assumes the (lhs, rhs, ret_val1, ret_val2, op) argument
  // order, i.e. this is a result-dtype mismatch.
  expect_rejected(
      alloc<CompareSelect>(int_imm, int_imm, int_imm, float_imm, kEQ));
  // First two operands differ (int vs float); the last two are both int.
  expect_rejected(
      alloc<CompareSelect>(int_imm, float_imm, int_imm, int_imm, kEQ));
}
63 
// A 4-lane Ramp built from a kInt var and a kFloat var is expected to be
// rejected by verify().
TEST(IRVerifier, Ramp) {
  VarPtr base = alloc<Var>("i", kInt);
  VarPtr stride = alloc<Var>("j", kFloat);
  // NOTE(review): names assume Ramp's argument order is (base, stride, lanes).
  ExprPtr a = alloc<Ramp>(base, stride, 4);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(a));
}
73 
// Index checks for Load on a 10x20 kFloat buffer: mixed integral index
// dtypes pass, float indices fail, and a multi-lane index next to another
// index fails.
TEST(IRVerifier, Load) {
  VarPtr int_idx = alloc<Var>("i", kInt);
  VarPtr long_idx = alloc<Var>("j", kLong);
  VarPtr float_var = alloc<Var>("k", kFloat);
  BufPtr buf = alloc<Buf>(
      "b",
      std::vector<ExprPtr>{alloc<IntImm>(10), alloc<IntImm>(20)},
      kFloat);

  {
    // kInt and kLong indices side by side are accepted.
    ExprPtr a = alloc<Load>(buf, std::vector<ExprPtr>{int_idx, long_idx});
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_NO_THROW(verify(a));
  }
  {
    // Indexing with kFloat values is rejected.
    ExprPtr a = alloc<Load>(buf, std::vector<ExprPtr>{float_var, float_var});
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // A 4-lane Ramp is only legal as a flattened index, not as one of two.
    ExprPtr four_lane_idx = alloc<Ramp>(int_idx, alloc<IntImm>(1), 4);
    ExprPtr a = alloc<Load>(buf, std::vector<ExprPtr>{int_idx, four_lane_idx});
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}
102 
// Malformed IfThenElse nodes: non-integral condition, branch dtype mismatch,
// and a multi-lane condition are each expected to be rejected by verify().
TEST(IRVerifier, IfThenElse) {
  VarPtr int_var = alloc<Var>("i", kInt);
  VarPtr long_var = alloc<Var>("j", kLong);
  VarPtr float_var = alloc<Var>("k", kFloat);

  // The node (and any sub-node) is built during argument evaluation, so only
  // verify() runs inside EXPECT_ANY_THROW.
  auto expect_rejected = [](const ExprPtr& a) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  };

  // The condition has to be integral; kFloat is not.
  expect_rejected(alloc<IfThenElse>(float_var, int_var, int_var));
  // The true/false branches have to agree in dtype; kInt vs kLong do not.
  expect_rejected(alloc<IfThenElse>(int_var, int_var, long_var));
  // The condition has to be single-lane; a 4-wide Broadcast is not.
  expect_rejected(
      alloc<IfThenElse>(alloc<Broadcast>(int_var, 4), int_var, int_var));
}
126 
// A For loop constructed with a nullptr loop Var is expected to be rejected
// by verify().
TEST(IRVerifier, For) {
  VarPtr loop_start = alloc<Var>("i", kInt);
  VarPtr loop_stop = alloc<Var>("j", kInt);
  StmtPtr empty_body = alloc<Block>(std::vector<StmtPtr>{});
  StmtPtr a = alloc<For>(nullptr, loop_start, loop_stop, empty_body);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_ANY_THROW(verify(a));
}
138 
// A single Store placed into two different Blocks is expected to make the
// second Block fail verification, because a Stmt may have only one parent.
TEST(IRVerifier, Block) {
  VarPtr idx = alloc<Var>("i", kInt);
  BufPtr buf = alloc<Buf>("B", std::vector<ExprPtr>{alloc<IntImm>(10)}, kInt);
  StmtPtr shared_store = alloc<Store>(buf, std::vector<ExprPtr>{idx}, idx);

  // Never read again: the point is only that this Block claims shared_store
  // first. NOTE(review): presumably the Block constructor records itself as
  // the parent — confirm against Block's implementation.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  StmtPtr first_owner = alloc<Block>(std::vector<StmtPtr>{shared_store});
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  StmtPtr second_owner = alloc<Block>(std::vector<StmtPtr>{shared_store});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(second_owner));
}
154 
// Store checks on a 10x20 kFloat buffer: mixed integral index dtypes pass;
// float indices, a multi-lane index next to another index, and a value whose
// dtype differs from the buffer's are each rejected.
TEST(IRVerifier, Store) {
  VarPtr int_idx = alloc<Var>("i", kInt);
  VarPtr long_idx = alloc<Var>("j", kLong);
  VarPtr float_var = alloc<Var>("k", kFloat);
  BufPtr buf = alloc<Buf>(
      "b",
      std::vector<ExprPtr>{alloc<IntImm>(10), alloc<IntImm>(20)},
      kFloat);

  {
    // kInt + kLong indices storing a kFloat value: accepted.
    StmtPtr a =
        alloc<Store>(buf, std::vector<ExprPtr>{int_idx, long_idx}, float_var);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_NO_THROW(verify(a));
  }
  {
    // kFloat values used as indices: rejected.
    StmtPtr a =
        alloc<Store>(buf, std::vector<ExprPtr>{float_var, float_var}, float_var);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // A 4-lane Ramp is only legal as a flattened index, not as one of two.
    ExprPtr four_lane_idx = alloc<Ramp>(int_idx, alloc<IntImm>(1), 4);
    StmtPtr a = alloc<Store>(
        buf, std::vector<ExprPtr>{int_idx, four_lane_idx}, float_var);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // Storing a kInt value into the kFloat buffer: rejected.
    StmtPtr a = alloc<Store>(buf, std::vector<ExprPtr>{int_idx}, int_idx);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}
189 
190 } // namespace jit
191 } // namespace torch
192