• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 
3 #include "torch/csrc/jit/tensorexpr/eval.h"
4 #include "torch/csrc/jit/tensorexpr/ir.h"
5 #include "torch/csrc/jit/tensorexpr/tensor.h"
6 
7 namespace torch {
8 namespace jit {
9 using namespace torch::jit::tensorexpr;
10 
TEST(Type,Test01)11 TEST(Type, Test01) {
12   {
13     Dtype dt1 = kInt;
14     ASSERT_EQ(dt1, kInt);
15   }
16   {
17     Dtype dt2_a(kInt, 8);
18     Dtype dt2_b(kInt, 4);
19     Dtype dt2_c(ScalarType::Int, 8);
20     ASSERT_EQ(dt2_a, dt2_c);
21     ASSERT_NE(dt2_a, dt2_b);
22   }
23   {
24     ASSERT_EQ(kInt, ToDtype<int>());
25     ASSERT_EQ(kFloat, ToDtype<float>());
26     ASSERT_EQ(kByte, ToDtype<uint8_t>());
27     ASSERT_EQ(kChar, ToDtype<int8_t>());
28     ASSERT_EQ(kShort, ToDtype<int16_t>());
29     ASSERT_EQ(kLong, ToDtype<int64_t>());
30     ASSERT_EQ(kHalf, ToDtype<at::Half>());
31     ASSERT_EQ(kDouble, ToDtype<double>());
32     ASSERT_EQ(kBool, ToDtype<bool>());
33   }
34   {
35     Dtype int32x8(kInt, 8);
36     Dtype float32x8(kFloat, 8);
37     ASSERT_NE(int32x8, float32x8);
38     ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8));
39     ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8));
40     ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8));
41     ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8));
42   }
43 }
44 
TEST(Type,BitCasting)45 TEST(Type, BitCasting) {
46   {
47     VarHandle x("x", kFloat);
48     ExprHandle y = bitcast<int32_t>(x);
49     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
50     ASSERT_EQ(y.dtype(), kInt);
51   }
52   {
53     VarHandle x("x", kInt);
54     ExprHandle y = bitcast<float>(x);
55     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
56     ASSERT_EQ(y.dtype(), kFloat);
57   }
58   {
59     VarHandle x("x", kShort);
60     ExprHandle y = bitcast<at::Half>(x);
61     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
62     ASSERT_EQ(y.dtype(), kHalf);
63   }
64   {
65     VarHandle x("x", kHalf);
66     ExprHandle y = bitcast<int16_t>(x);
67     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
68     ASSERT_EQ(y.dtype(), kShort);
69   }
70 
71   constexpr int32_t ref32 = 1337;
72   constexpr int64_t ref64 = 1337;
73   constexpr float reff32 = 1337.0f;
74   constexpr double reff64 = 1337.0f;
75   using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
76   // this is broken
77   /*{
78     constexpr int16_t ref16 = 1337;
79     at::Half k_;
80     at::Half* k = &k_;
81     *reinterpret_cast<int16_t*>(k) = ref16;
82     auto a = HalfImm::make(*k);
83     auto b = BitCast::make(kShort, a);
84     SimpleIRExprEval cg(b);
85     ASSERT_EQ(cg.value<int16_t>(), ref16);
86   }*/
87 
88   {
89     float k = raw_bitcast<float>(ref32);
90     auto a = FloatImm::make(k);
91     auto b = BitCast::make(kInt, a);
92     SimpleIRExprEval cg(b);
93     ASSERT_EQ(cg.value<int32_t>(), ref32);
94   }
95 
96   {
97     double k = raw_bitcast<double>(ref64);
98     auto a = DoubleImm::make(k);
99     auto b = BitCast::make(kLong, a);
100     SimpleIRExprEval cg(b);
101     ASSERT_EQ(cg.value<int64_t>(), ref64);
102   }
103 
104   {
105     int64_t k = raw_bitcast<int64_t>(reff64);
106     auto a = LongImm::make(k);
107     auto b = BitCast::make(kDouble, a);
108     SimpleIRExprEval cg(b);
109     ASSERT_EQ(cg.value<double>(), reff64);
110   }
111 
112   {
113     int32_t k = raw_bitcast<int32_t>(reff32);
114     auto a = IntImm::make(k);
115     auto b = BitCast::make(kFloat, a);
116     SimpleIRExprEval cg(b);
117     ASSERT_EQ(cg.value<float>(), reff32);
118   }
119 
120   // This segfaults :(
121   /*{
122     VarHandle x("x", kDouble);
123     ASSERT_ANY_THROW(ExprHandle y = bitcast<int32_t>(x));
124   }
125   {
126     VarHandle x("x", kFloat);
127     ASSERT_ANY_THROW(ExprHandle y = bitcast<int64_t>(x));
128   }
129   {
130     VarHandle x("x", kLong);
131     ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
132   }
133   {
134     VarHandle x("x", kShort);
135     ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
136   }
137   {
138     VarHandle x("x", kInt);
139     ASSERT_ANY_THROW(ExprHandle y = bitcast<at::Half>(x));
140   }*/
141 }
142 
TEST(Type,Propagation)143 TEST(Type, Propagation) {
144   // Same types:
145   {
146     VarHandle x("x", kFloat);
147     VarHandle y("y", kFloat);
148     ExprHandle body = FloatImm::make(2.f) +
149         (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y);
150     ASSERT_EQ(body.dtype(), kFloat);
151   }
152   // Int to bigger int:
153   {
154     VarHandle x("x", kShort);
155     VarHandle y("y", kLong);
156     ExprHandle body =
157         ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y);
158     ASSERT_EQ(body.dtype(), kLong);
159   }
160   // Float to bigger float:
161   {
162     VarHandle x("x", kHalf);
163     VarHandle y("y", kDouble);
164     ExprHandle body =
165         HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
166     ASSERT_EQ(body.dtype(), kDouble);
167   }
168   // Int to Float:
169   {
170     VarHandle x("x", kFloat);
171     VarHandle y("y", kInt);
172     ExprHandle body =
173         IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y);
174     ASSERT_EQ(body.dtype(), kFloat);
175   }
176   // Smaller float, bigger Int:
177   {
178     VarHandle x("x", kHalf);
179     VarHandle y("y", kLong);
180     ExprHandle body =
181         HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
182     ASSERT_EQ(body.dtype(), kHalf);
183   }
184   // Bigger float, smaller Int:
185   {
186     VarHandle x("x", kChar);
187     VarHandle y("y", kDouble);
188     ExprHandle body =
189         CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
190     ASSERT_EQ(body.dtype(), kDouble);
191   }
192   // Sign change char/byte upgrades to short:
193   {
194     VarHandle x("x", kChar);
195     VarHandle y("y", kByte);
196     ExprHandle body =
197         CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
198     ASSERT_EQ(body.dtype(), kShort);
199   }
200 }
201 } // namespace jit
202 } // namespace torch
203