• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
17 #define COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
18 
19 /*
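Operand definitions for the code generator: registers (Reg), immediates (Imm, TypedImm),
memory references (MemRef) and shifted registers (Shift).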
21 */
22 #include <bitset>
23 #include <cstdint>
24 #include <type_traits>
25 
26 #include "type_info.h"
27 #include "utils/arch.h"
28 #include "utils/arena_containers.h"
29 #include "utils/bit_field.h"
30 #include "utils/bit_utils.h"
31 #include "utils/regmask.h"
32 #include "compiler/optimizer/ir/constants.h"
33 #include "compiler/optimizer/ir/datatype.h"
34 #include "utils/type_helpers.h"
35 
36 namespace ark::compiler {
// Register mapping model. Every enumerator is a distinct bit, so several can be OR-ed together.
//   <CLASS>_<CLASS>      (e.g. SCALAR_SCALAR): getters for smaller parts of a register of that class;
//   <CLASS>_<OTHER>      (e.g. SCALAR_FLOAT):  mapping between registers of two different classes.
enum RegMapping : uint32_t {
    SCALAR_SCALAR = 1U << 0U,
    SCALAR_VECTOR = 1U << 1U,
    SCALAR_FLOAT = 1U << 2U,
    VECTOR_VECTOR = 1U << 3U,
    VECTOR_FLOAT = 1U << 4U,
    FLOAT_FLOAT = 1U << 5U
};
48 
49 constexpr uint8_t INVALID_REG_ID = std::numeric_limits<uint8_t>::max();
50 constexpr uint8_t ACC_REG_ID = INVALID_REG_ID - 1U;
51 #ifdef ENABLE_LIBABCKIT
GetAccReg()52 inline Register GetAccReg()
53 {
54     return IsFrameSizeLarge() ? INVALID_REG_LARGE - 1U : ACC_REG_ID;
55 }
56 #else
GetAccReg()57 inline Register GetAccReg()
58 {
59     return ACC_REG_ID;
60 }
61 #endif
62 
63 class Reg final {
64 public:
65     using RegIDType = uint8_t;
66     using RegSizeType = size_t;
67 
68     constexpr Reg() = default;
69     DEFAULT_MOVE_SEMANTIC(Reg);
70     DEFAULT_COPY_SEMANTIC(Reg);
71     ~Reg() = default;
72 
73     // Default register constructor
Reg(RegIDType id,TypeInfo type)74     constexpr Reg(RegIDType id, TypeInfo type) : id_(id), type_(type) {}
75 
GetId()76     constexpr RegIDType GetId() const
77     {
78         return id_;
79     }
80 
GetMask()81     constexpr size_t GetMask() const
82     {
83         CHECK_LT(id_, 32U);
84         return (1U << id_);
85     }
86 
GetType()87     constexpr TypeInfo GetType() const
88     {
89         return type_;
90     }
91 
GetSize()92     RegSizeType GetSize() const
93     {
94         return GetType().GetSize();
95     }
96 
IsScalar()97     bool IsScalar() const
98     {
99         return GetType().IsScalar();
100     }
101 
IsFloat()102     bool IsFloat() const
103     {
104         return GetType().IsFloat();
105     }
106 
IsValid()107     constexpr bool IsValid() const
108     {
109         return type_ != INVALID_TYPE && id_ != INVALID_REG_ID;
110     }
111 
As(TypeInfo type)112     Reg As(TypeInfo type) const
113     {
114         return Reg(GetId(), type);
115     }
116 
117     constexpr bool operator==(Reg other) const
118     {
119         return (GetId() == other.GetId()) && (GetType() == other.GetType());
120     }
121 
122     constexpr bool operator!=(Reg other) const
123     {
124         return !operator==(other);
125     }
126 
Dump()127     void Dump()
128     {
129         std::cerr << " Reg: id = " << static_cast<int64_t>(id_) << ", ";
130         type_.Dump();
131         std::cerr << "\n";
132     }
133 
134 private:
135     RegIDType id_ {INVALID_REG_ID};
136     TypeInfo type_ {INVALID_TYPE};
137 };  // Reg
138 
139 constexpr Reg INVALID_REGISTER = Reg();
140 
141 static_assert(!INVALID_REGISTER.IsValid());
142 static_assert(sizeof(Reg) <= sizeof(uintptr_t));
143 
144 /**
145  * Immediate class may hold only int or float values (maybe vectors in future).
146  * It knows nothing about pointers and bools (bools maybe be in future).
147  */
148 class Imm final {
149     static constexpr size_t UNDEFINED_SIZE = 0;
150     static constexpr size_t INT64_SIZE = 64;
151     static constexpr size_t FLOAT32_SIZE = 32;
152     static constexpr size_t FLOAT64_SIZE = 64;
153 
154 public:
155     constexpr Imm() = default;
156 
157     template <typename T>
Imm(T value)158     constexpr explicit Imm(T value) : value_(static_cast<int64_t>(value))
159     {
160         using Type = std::decay_t<T>;
161         static_assert(std::is_integral_v<Type> || std::is_enum_v<Type>);
162     }
163 
164     // Partial template specialization
Imm(int64_t value)165     constexpr explicit Imm(int64_t value) : value_(value) {};
166 #ifndef NDEBUG
Imm(double value)167     constexpr explicit Imm(double value) : value_(value) {};
Imm(float value)168     constexpr explicit Imm(float value) : value_(value) {};
169 #else
Imm(double value)170     explicit Imm(double value) : value_(bit_cast<uint64_t>(value)) {};
Imm(float value)171     explicit Imm(float value) : value_(bit_cast<uint32_t>(value)) {};
172 #endif  // !NDEBUG
173 
174     DEFAULT_MOVE_SEMANTIC(Imm);
175     DEFAULT_COPY_SEMANTIC(Imm);
176     ~Imm() = default;
177 
178 #ifdef NDEBUG
GetAsInt()179     constexpr int64_t GetAsInt() const
180     {
181         return value_;
182     }
183 
GetAsFloat()184     float GetAsFloat() const
185     {
186         return bit_cast<float>(static_cast<int32_t>(value_));
187     }
188 
GetAsDouble()189     double GetAsDouble() const
190     {
191         return bit_cast<double>(value_);
192     }
193 
GetRawValue()194     constexpr int64_t GetRawValue() const
195     {
196         return value_;
197     }
198 
199 #else
GetAsInt()200     constexpr int64_t GetAsInt() const
201     {
202         ASSERT(std::holds_alternative<int64_t>(value_));
203         return std::get<int64_t>(value_);
204     }
205 
GetAsFloat()206     float GetAsFloat() const
207     {
208         ASSERT(std::holds_alternative<float>(value_));
209         return std::get<float>(value_);
210     }
211 
GetAsDouble()212     double GetAsDouble() const
213     {
214         ASSERT(std::holds_alternative<double>(value_));
215         return std::get<double>(value_);
216     }
217 
GetRawValue()218     constexpr int64_t GetRawValue() const
219     {
220         if (value_.index() == 0) {
221             UNREACHABLE();
222         } else if (value_.index() == 1) {
223             return std::get<int64_t>(value_);
224         } else if (value_.index() == 2U) {
225             return static_cast<int64_t>(bit_cast<int32_t>(std::get<float>(value_)));
226         } else if (value_.index() == 3U) {
227             return bit_cast<int64_t>(std::get<double>(value_));
228         }
229         UNREACHABLE();
230     }
231 
232     enum VariantID {
233         V_INVALID = 0,  // Pointer used for invalidate variants
234         V_INT64 = 1,
235         V_FLOAT32 = 2,
236         V_FLOAT64 = 3
237     };
238 
239     template <class T>
CheckVariantID()240     constexpr bool CheckVariantID() const
241     {
242 #ifndef __clang_analyzer__
243         // Immediate could be only signed (int/float)
244         // look at value_-type.
245         static_assert(std::is_signed_v<T>);
246         if constexpr (std::is_same<T, int64_t>()) {
247             return value_.index() == V_INT64;
248         }
249         if constexpr (std::is_same<T, float>()) {
250             return value_.index() == V_FLOAT32;
251         }
252         if constexpr (std::is_same<T, double>()) {
253             return value_.index() == V_FLOAT64;
254         }
255         return false;
256 #else
257         return true;
258 #endif  // !__clang_analyzer__
259     }
260 
IsValid()261     constexpr bool IsValid() const
262     {
263         return !std::holds_alternative<void *>(value_);
264     }
265 
GetType()266     TypeInfo GetType() const
267     {
268         switch (value_.index()) {
269             case V_INT64:
270                 return INT64_TYPE;
271             case V_FLOAT32:
272                 return FLOAT32_TYPE;
273             case V_FLOAT64:
274                 return FLOAT64_TYPE;
275             default:
276                 UNREACHABLE();
277                 return INVALID_TYPE;
278         }
279     }
280 
GetSize()281     constexpr size_t GetSize() const
282     {
283         switch (value_.index()) {
284             case V_INT64:
285                 return INT64_SIZE;
286             case V_FLOAT32:
287                 return FLOAT32_SIZE;
288             case V_FLOAT64:
289                 return FLOAT64_SIZE;
290             default:
291                 UNREACHABLE();
292                 return UNDEFINED_SIZE;
293         }
294     }
295 #endif  // NDEBUG
296 
297     bool operator==(Imm other) const
298     {
299         return value_ == other.value_;
300     }
301 
302     bool operator!=(Imm other) const
303     {
304         return !(operator==(other));
305     }
306 
307 private:
308 #ifndef NDEBUG
309     std::variant<void *, int64_t, float, double> value_ {nullptr};
310 #else
311     int64_t value_ {0};
312 #endif  // NDEBUG
313 };      // Imm
314 
315 class TypedImm final {
316 public:
317     template <typename T>
TypedImm(T imm)318     constexpr explicit TypedImm(T imm) : type_(imm), imm_(imm)
319     {
320     }
321 
GetType()322     TypeInfo GetType() const
323     {
324         return type_;
325     }
326 
GetImm()327     Imm GetImm() const
328     {
329         return imm_;
330     }
331 
332 private:
333     TypeInfo type_ {INVALID_TYPE};
334     Imm imm_ {0};
335 };
336 
337 // Why memory ref - because you may create one link for one encode-session
338 // And when you see this one - you can easy understand, what type of memory
339 //   you use. But if you load/store dirrectly address - you need to decode it
340 //   each time, when you read code
341 // model -> base + index<<scale + disp
342 class MemRef final {
343 public:
344     MemRef() = default;
345 
MemRef(Reg base)346     explicit MemRef(Reg base) : MemRef(base, 0) {}
MemRef(Reg base,ssize_t disp)347     MemRef(Reg base, ssize_t disp) : MemRef(base, INVALID_REGISTER, 0, disp) {}
MemRef(Reg base,Reg index,uint16_t scale)348     MemRef(Reg base, Reg index, uint16_t scale) : MemRef(base, index, scale, 0) {}
MemRef(Reg base,Reg index,uint16_t scale,ssize_t disp)349     MemRef(Reg base, Reg index, uint16_t scale, ssize_t disp) : disp_(disp), scale_(scale), base_(base), index_(index)
350     {
351         CHECK_LE(disp, std::numeric_limits<decltype(disp_)>::max());
352         CHECK_LE(scale, std::numeric_limits<decltype(scale_)>::max());
353     }
354     DEFAULT_MOVE_SEMANTIC(MemRef);
355     DEFAULT_COPY_SEMANTIC(MemRef);
356     ~MemRef() = default;
357 
GetBase()358     Reg GetBase() const
359     {
360         return base_;
361     }
GetIndex()362     Reg GetIndex() const
363     {
364         return index_;
365     }
GetScale()366     auto GetScale() const
367     {
368         return scale_;
369     }
GetDisp()370     auto GetDisp() const
371     {
372         return disp_;
373     }
374 
HasBase()375     bool HasBase() const
376     {
377         return base_.IsValid();
378     }
HasIndex()379     bool HasIndex() const
380     {
381         return index_.IsValid();
382     }
HasScale()383     bool HasScale() const
384     {
385         return HasIndex() && scale_ != 0;
386     }
HasDisp()387     bool HasDisp() const
388     {
389         return disp_ != 0;
390     }
391     // Ref must contain at least one of field
IsValid()392     bool IsValid() const
393     {
394         return HasBase() || HasIndex() || HasScale() || HasDisp();
395     }
396 
397     // return true if mem doesn't has index and scalar
IsOffsetMem()398     bool IsOffsetMem() const
399     {
400         return !HasIndex() && !HasScale();
401     }
402 
403     bool operator==(MemRef other) const
404     {
405         return (base_ == other.base_) && (index_ == other.index_) && (scale_ == other.scale_) && (disp_ == other.disp_);
406     }
407     bool operator!=(MemRef other) const
408     {
409         return !(operator==(other));
410     }
411 
412 private:
413     ssize_t disp_ {0};
414     uint16_t scale_ {0};
415     Reg base_ {INVALID_REGISTER};
416     Reg index_ {INVALID_REGISTER};
417 };  // MemRef
418 
419 class Shift final {
420 public:
Shift(Reg base,ShiftType type,uint32_t scale)421     explicit Shift(Reg base, ShiftType type, uint32_t scale) : scale_(scale), base_(base), type_(type) {}
Shift(Reg base,uint32_t scale)422     explicit Shift(Reg base, uint32_t scale) : Shift(base, ShiftType::LSL, scale) {}
423 
424     DEFAULT_MOVE_SEMANTIC(Shift);
425     DEFAULT_COPY_SEMANTIC(Shift);
426     ~Shift() = default;
427 
GetBase()428     Reg GetBase() const
429     {
430         return base_;
431     }
432 
GetType()433     ShiftType GetType() const
434     {
435         return type_;
436     }
437 
GetScale()438     uint32_t GetScale() const
439     {
440         return scale_;
441     }
442 
443 private:
444     uint32_t scale_ {0};
445     Reg base_;
446     ShiftType type_ {INVALID_SHIFT};
447 };
448 
449 }  // namespace ark::compiler
#endif  // COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
451