1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
17 #define COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
18
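/*
Operand definitions for the code generator: registers (Reg), immediates (Imm, TypedImm),
memory references (MemRef) and shifted registers (Shift).
*/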
#include <bitset>
#include <cstdint>
#include <iostream>
#include <limits>
#include <type_traits>
#include <variant>

#include "type_info.h"
#include "utils/arch.h"
#include "utils/arena_containers.h"
#include "utils/bit_field.h"
#include "utils/bit_utils.h"
#include "utils/regmask.h"
#include "compiler/optimizer/ir/constants.h"
#include "compiler/optimizer/ir/datatype.h"
#include "utils/type_helpers.h"
35
36 namespace ark::compiler {
// Register mapping model; every enumerator occupies its own single bit.
//   reg-reg   (SCALAR_SCALAR, VECTOR_VECTOR, FLOAT_FLOAT): getters for small parts of a register are supported.
//   reg-other (SCALAR_VECTOR, SCALAR_FLOAT, VECTOR_FLOAT): mapping between different kinds of registers.
enum RegMapping : uint32_t {
    SCALAR_SCALAR = 1U << 0U,
    SCALAR_VECTOR = 1U << 1U,
    SCALAR_FLOAT = 1U << 2U,
    VECTOR_VECTOR = 1U << 3U,
    VECTOR_FLOAT = 1U << 4U,
    FLOAT_FLOAT = 1U << 5U
};
48
49 constexpr uint8_t INVALID_REG_ID = std::numeric_limits<uint8_t>::max();
50 constexpr uint8_t ACC_REG_ID = INVALID_REG_ID - 1U;
51 #ifdef ENABLE_LIBABCKIT
GetAccReg()52 inline Register GetAccReg()
53 {
54 return IsFrameSizeLarge() ? INVALID_REG_LARGE - 1U : ACC_REG_ID;
55 }
56 #else
GetAccReg()57 inline Register GetAccReg()
58 {
59 return ACC_REG_ID;
60 }
61 #endif
62
63 class Reg final {
64 public:
65 using RegIDType = uint8_t;
66 using RegSizeType = size_t;
67
68 constexpr Reg() = default;
69 DEFAULT_MOVE_SEMANTIC(Reg);
70 DEFAULT_COPY_SEMANTIC(Reg);
71 ~Reg() = default;
72
73 // Default register constructor
Reg(RegIDType id,TypeInfo type)74 constexpr Reg(RegIDType id, TypeInfo type) : id_(id), type_(type) {}
75
GetId()76 constexpr RegIDType GetId() const
77 {
78 return id_;
79 }
80
GetMask()81 constexpr size_t GetMask() const
82 {
83 CHECK_LT(id_, 32U);
84 return (1U << id_);
85 }
86
GetType()87 constexpr TypeInfo GetType() const
88 {
89 return type_;
90 }
91
GetSize()92 RegSizeType GetSize() const
93 {
94 return GetType().GetSize();
95 }
96
IsScalar()97 bool IsScalar() const
98 {
99 return GetType().IsScalar();
100 }
101
IsFloat()102 bool IsFloat() const
103 {
104 return GetType().IsFloat();
105 }
106
IsValid()107 constexpr bool IsValid() const
108 {
109 return type_ != INVALID_TYPE && id_ != INVALID_REG_ID;
110 }
111
As(TypeInfo type)112 Reg As(TypeInfo type) const
113 {
114 return Reg(GetId(), type);
115 }
116
117 constexpr bool operator==(Reg other) const
118 {
119 return (GetId() == other.GetId()) && (GetType() == other.GetType());
120 }
121
122 constexpr bool operator!=(Reg other) const
123 {
124 return !operator==(other);
125 }
126
Dump()127 void Dump()
128 {
129 std::cerr << " Reg: id = " << static_cast<int64_t>(id_) << ", ";
130 type_.Dump();
131 std::cerr << "\n";
132 }
133
134 private:
135 RegIDType id_ {INVALID_REG_ID};
136 TypeInfo type_ {INVALID_TYPE};
137 }; // Reg
138
139 constexpr Reg INVALID_REGISTER = Reg();
140
141 static_assert(!INVALID_REGISTER.IsValid());
142 static_assert(sizeof(Reg) <= sizeof(uintptr_t));
143
144 /**
145 * Immediate class may hold only int or float values (maybe vectors in future).
146 * It knows nothing about pointers and bools (bools maybe be in future).
147 */
148 class Imm final {
149 static constexpr size_t UNDEFINED_SIZE = 0;
150 static constexpr size_t INT64_SIZE = 64;
151 static constexpr size_t FLOAT32_SIZE = 32;
152 static constexpr size_t FLOAT64_SIZE = 64;
153
154 public:
155 constexpr Imm() = default;
156
157 template <typename T>
Imm(T value)158 constexpr explicit Imm(T value) : value_(static_cast<int64_t>(value))
159 {
160 using Type = std::decay_t<T>;
161 static_assert(std::is_integral_v<Type> || std::is_enum_v<Type>);
162 }
163
164 // Partial template specialization
Imm(int64_t value)165 constexpr explicit Imm(int64_t value) : value_(value) {};
166 #ifndef NDEBUG
Imm(double value)167 constexpr explicit Imm(double value) : value_(value) {};
Imm(float value)168 constexpr explicit Imm(float value) : value_(value) {};
169 #else
Imm(double value)170 explicit Imm(double value) : value_(bit_cast<uint64_t>(value)) {};
Imm(float value)171 explicit Imm(float value) : value_(bit_cast<uint32_t>(value)) {};
172 #endif // !NDEBUG
173
174 DEFAULT_MOVE_SEMANTIC(Imm);
175 DEFAULT_COPY_SEMANTIC(Imm);
176 ~Imm() = default;
177
178 #ifdef NDEBUG
GetAsInt()179 constexpr int64_t GetAsInt() const
180 {
181 return value_;
182 }
183
GetAsFloat()184 float GetAsFloat() const
185 {
186 return bit_cast<float>(static_cast<int32_t>(value_));
187 }
188
GetAsDouble()189 double GetAsDouble() const
190 {
191 return bit_cast<double>(value_);
192 }
193
GetRawValue()194 constexpr int64_t GetRawValue() const
195 {
196 return value_;
197 }
198
199 #else
GetAsInt()200 constexpr int64_t GetAsInt() const
201 {
202 ASSERT(std::holds_alternative<int64_t>(value_));
203 return std::get<int64_t>(value_);
204 }
205
GetAsFloat()206 float GetAsFloat() const
207 {
208 ASSERT(std::holds_alternative<float>(value_));
209 return std::get<float>(value_);
210 }
211
GetAsDouble()212 double GetAsDouble() const
213 {
214 ASSERT(std::holds_alternative<double>(value_));
215 return std::get<double>(value_);
216 }
217
GetRawValue()218 constexpr int64_t GetRawValue() const
219 {
220 if (value_.index() == 0) {
221 UNREACHABLE();
222 } else if (value_.index() == 1) {
223 return std::get<int64_t>(value_);
224 } else if (value_.index() == 2U) {
225 return static_cast<int64_t>(bit_cast<int32_t>(std::get<float>(value_)));
226 } else if (value_.index() == 3U) {
227 return bit_cast<int64_t>(std::get<double>(value_));
228 }
229 UNREACHABLE();
230 }
231
232 enum VariantID {
233 V_INVALID = 0, // Pointer used for invalidate variants
234 V_INT64 = 1,
235 V_FLOAT32 = 2,
236 V_FLOAT64 = 3
237 };
238
239 template <class T>
CheckVariantID()240 constexpr bool CheckVariantID() const
241 {
242 #ifndef __clang_analyzer__
243 // Immediate could be only signed (int/float)
244 // look at value_-type.
245 static_assert(std::is_signed_v<T>);
246 if constexpr (std::is_same<T, int64_t>()) {
247 return value_.index() == V_INT64;
248 }
249 if constexpr (std::is_same<T, float>()) {
250 return value_.index() == V_FLOAT32;
251 }
252 if constexpr (std::is_same<T, double>()) {
253 return value_.index() == V_FLOAT64;
254 }
255 return false;
256 #else
257 return true;
258 #endif // !__clang_analyzer__
259 }
260
IsValid()261 constexpr bool IsValid() const
262 {
263 return !std::holds_alternative<void *>(value_);
264 }
265
GetType()266 TypeInfo GetType() const
267 {
268 switch (value_.index()) {
269 case V_INT64:
270 return INT64_TYPE;
271 case V_FLOAT32:
272 return FLOAT32_TYPE;
273 case V_FLOAT64:
274 return FLOAT64_TYPE;
275 default:
276 UNREACHABLE();
277 return INVALID_TYPE;
278 }
279 }
280
GetSize()281 constexpr size_t GetSize() const
282 {
283 switch (value_.index()) {
284 case V_INT64:
285 return INT64_SIZE;
286 case V_FLOAT32:
287 return FLOAT32_SIZE;
288 case V_FLOAT64:
289 return FLOAT64_SIZE;
290 default:
291 UNREACHABLE();
292 return UNDEFINED_SIZE;
293 }
294 }
295 #endif // NDEBUG
296
297 bool operator==(Imm other) const
298 {
299 return value_ == other.value_;
300 }
301
302 bool operator!=(Imm other) const
303 {
304 return !(operator==(other));
305 }
306
307 private:
308 #ifndef NDEBUG
309 std::variant<void *, int64_t, float, double> value_ {nullptr};
310 #else
311 int64_t value_ {0};
312 #endif // NDEBUG
313 }; // Imm
314
315 class TypedImm final {
316 public:
317 template <typename T>
TypedImm(T imm)318 constexpr explicit TypedImm(T imm) : type_(imm), imm_(imm)
319 {
320 }
321
GetType()322 TypeInfo GetType() const
323 {
324 return type_;
325 }
326
GetImm()327 Imm GetImm() const
328 {
329 return imm_;
330 }
331
332 private:
333 TypeInfo type_ {INVALID_TYPE};
334 Imm imm_ {0};
335 };
336
337 // Why memory ref - because you may create one link for one encode-session
338 // And when you see this one - you can easy understand, what type of memory
339 // you use. But if you load/store dirrectly address - you need to decode it
340 // each time, when you read code
341 // model -> base + index<<scale + disp
342 class MemRef final {
343 public:
344 MemRef() = default;
345
MemRef(Reg base)346 explicit MemRef(Reg base) : MemRef(base, 0) {}
MemRef(Reg base,ssize_t disp)347 MemRef(Reg base, ssize_t disp) : MemRef(base, INVALID_REGISTER, 0, disp) {}
MemRef(Reg base,Reg index,uint16_t scale)348 MemRef(Reg base, Reg index, uint16_t scale) : MemRef(base, index, scale, 0) {}
MemRef(Reg base,Reg index,uint16_t scale,ssize_t disp)349 MemRef(Reg base, Reg index, uint16_t scale, ssize_t disp) : disp_(disp), scale_(scale), base_(base), index_(index)
350 {
351 CHECK_LE(disp, std::numeric_limits<decltype(disp_)>::max());
352 CHECK_LE(scale, std::numeric_limits<decltype(scale_)>::max());
353 }
354 DEFAULT_MOVE_SEMANTIC(MemRef);
355 DEFAULT_COPY_SEMANTIC(MemRef);
356 ~MemRef() = default;
357
GetBase()358 Reg GetBase() const
359 {
360 return base_;
361 }
GetIndex()362 Reg GetIndex() const
363 {
364 return index_;
365 }
GetScale()366 auto GetScale() const
367 {
368 return scale_;
369 }
GetDisp()370 auto GetDisp() const
371 {
372 return disp_;
373 }
374
HasBase()375 bool HasBase() const
376 {
377 return base_.IsValid();
378 }
HasIndex()379 bool HasIndex() const
380 {
381 return index_.IsValid();
382 }
HasScale()383 bool HasScale() const
384 {
385 return HasIndex() && scale_ != 0;
386 }
HasDisp()387 bool HasDisp() const
388 {
389 return disp_ != 0;
390 }
391 // Ref must contain at least one of field
IsValid()392 bool IsValid() const
393 {
394 return HasBase() || HasIndex() || HasScale() || HasDisp();
395 }
396
397 // return true if mem doesn't has index and scalar
IsOffsetMem()398 bool IsOffsetMem() const
399 {
400 return !HasIndex() && !HasScale();
401 }
402
403 bool operator==(MemRef other) const
404 {
405 return (base_ == other.base_) && (index_ == other.index_) && (scale_ == other.scale_) && (disp_ == other.disp_);
406 }
407 bool operator!=(MemRef other) const
408 {
409 return !(operator==(other));
410 }
411
412 private:
413 ssize_t disp_ {0};
414 uint16_t scale_ {0};
415 Reg base_ {INVALID_REGISTER};
416 Reg index_ {INVALID_REGISTER};
417 }; // MemRef
418
419 class Shift final {
420 public:
Shift(Reg base,ShiftType type,uint32_t scale)421 explicit Shift(Reg base, ShiftType type, uint32_t scale) : scale_(scale), base_(base), type_(type) {}
Shift(Reg base,uint32_t scale)422 explicit Shift(Reg base, uint32_t scale) : Shift(base, ShiftType::LSL, scale) {}
423
424 DEFAULT_MOVE_SEMANTIC(Shift);
425 DEFAULT_COPY_SEMANTIC(Shift);
426 ~Shift() = default;
427
GetBase()428 Reg GetBase() const
429 {
430 return base_;
431 }
432
GetType()433 ShiftType GetType() const
434 {
435 return type_;
436 }
437
GetScale()438 uint32_t GetScale() const
439 {
440 return scale_;
441 }
442
443 private:
444 uint32_t scale_ {0};
445 Reg base_;
446 ShiftType type_ {INVALID_SHIFT};
447 };
448
449 } // namespace ark::compiler
#endif  // COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
451