• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
17 #define COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
18 
19 /*
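Operand definitions for the code generator: registers (Reg), immediates (Imm, TypedImm), memory references (MemRef) and register shifts (Shift).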
21 */
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <type_traits>
#include <variant>

#include "type_info.h"
#include "utils/arch.h"
#include "utils/arena_containers.h"
#include "utils/bit_field.h"
#include "utils/bit_utils.h"
#include "utils/regmask.h"
#include "compiler/optimizer/ir/constants.h"
#include "compiler/optimizer/ir/datatype.h"
#include "utils/type_helpers.h"
35 
36 namespace ark::compiler {
// Register mapping model, expressed as independent bit flags:
//   <kind>-<same kind>  - getters for smaller parts of a register are supported
//   <kind>-<other kind> - mapping between different kinds of registers is supported
enum RegMapping : uint32_t {
    SCALAR_SCALAR = 0x01U,  // bit 0
    SCALAR_VECTOR = 0x02U,  // bit 1
    SCALAR_FLOAT = 0x04U,   // bit 2
    VECTOR_VECTOR = 0x08U,  // bit 3
    VECTOR_FLOAT = 0x10U,   // bit 4
    FLOAT_FLOAT = 0x20U     // bit 5
};
48 
// Sentinel register id meaning "no register": the largest value representable by uint8_t.
// `inline` (C++17) gives the header a single definition shared by every translation unit
// instead of one internal-linkage copy per TU.
inline constexpr uint8_t INVALID_REG_ID = std::numeric_limits<uint8_t>::max();
// Id named ACC_REG_ID, placed directly below INVALID_REG_ID so the two never collide.
inline constexpr uint8_t ACC_REG_ID = static_cast<uint8_t>(INVALID_REG_ID - 1U);
51 
52 class Reg final {
53 public:
54     using RegIDType = uint8_t;
55     using RegSizeType = size_t;
56 
57     constexpr Reg() = default;
58     DEFAULT_MOVE_SEMANTIC(Reg);
59     DEFAULT_COPY_SEMANTIC(Reg);
60     ~Reg() = default;
61 
62     // Default register constructor
Reg(RegIDType id,TypeInfo type)63     constexpr Reg(RegIDType id, TypeInfo type) : id_(id), type_(type) {}
64 
GetId()65     constexpr RegIDType GetId() const
66     {
67         return id_;
68     }
69 
GetMask()70     constexpr size_t GetMask() const
71     {
72         CHECK_LT(id_, 32U);
73         return (1U << id_);
74     }
75 
GetType()76     constexpr TypeInfo GetType() const
77     {
78         return type_;
79     }
80 
GetSize()81     RegSizeType GetSize() const
82     {
83         return GetType().GetSize();
84     }
85 
IsScalar()86     bool IsScalar() const
87     {
88         return GetType().IsScalar();
89     }
90 
IsFloat()91     bool IsFloat() const
92     {
93         return GetType().IsFloat();
94     }
95 
IsValid()96     constexpr bool IsValid() const
97     {
98         return type_ != INVALID_TYPE && id_ != INVALID_REG_ID;
99     }
100 
As(TypeInfo type)101     Reg As(TypeInfo type) const
102     {
103         return Reg(GetId(), type);
104     }
105 
106     constexpr bool operator==(Reg other) const
107     {
108         return (GetId() == other.GetId()) && (GetType() == other.GetType());
109     }
110 
111     constexpr bool operator!=(Reg other) const
112     {
113         return !operator==(other);
114     }
115 
Dump()116     void Dump()
117     {
118         std::cerr << " Reg: id = " << static_cast<int64_t>(id_) << ", ";
119         type_.Dump();
120         std::cerr << "\n";
121     }
122 
123 private:
124     RegIDType id_ {INVALID_REG_ID};
125     TypeInfo type_ {INVALID_TYPE};
126 };  // Reg
127 
128 constexpr Reg INVALID_REGISTER = Reg();
129 
130 static_assert(!INVALID_REGISTER.IsValid());
131 static_assert(sizeof(Reg) <= sizeof(uintptr_t));
132 
133 /**
134  * Immediate class may hold only int or float values (maybe vectors in future).
135  * It knows nothing about pointers and bools (bools maybe be in future).
136  */
137 class Imm final {
138     static constexpr size_t UNDEFINED_SIZE = 0;
139     static constexpr size_t INT64_SIZE = 64;
140     static constexpr size_t FLOAT32_SIZE = 32;
141     static constexpr size_t FLOAT64_SIZE = 64;
142 
143 public:
144     constexpr Imm() = default;
145 
146     template <typename T>
Imm(T value)147     constexpr explicit Imm(T value) : value_(static_cast<int64_t>(value))
148     {
149         using Type = std::decay_t<T>;
150         static_assert(std::is_integral_v<Type> || std::is_enum_v<Type>);
151     }
152 
153     // Partial template specialization
Imm(int64_t value)154     constexpr explicit Imm(int64_t value) : value_(value) {};
155 #ifndef NDEBUG
Imm(double value)156     constexpr explicit Imm(double value) : value_(value) {};
Imm(float value)157     constexpr explicit Imm(float value) : value_(value) {};
158 #else
Imm(double value)159     explicit Imm(double value) : value_(bit_cast<uint64_t>(value)) {};
Imm(float value)160     explicit Imm(float value) : value_(bit_cast<uint32_t>(value)) {};
161 #endif  // !NDEBUG
162 
163     DEFAULT_MOVE_SEMANTIC(Imm);
164     DEFAULT_COPY_SEMANTIC(Imm);
165     ~Imm() = default;
166 
167 #ifdef NDEBUG
GetAsInt()168     constexpr int64_t GetAsInt() const
169     {
170         return value_;
171     }
172 
GetAsFloat()173     float GetAsFloat() const
174     {
175         return bit_cast<float>(static_cast<int32_t>(value_));
176     }
177 
GetAsDouble()178     double GetAsDouble() const
179     {
180         return bit_cast<double>(value_);
181     }
182 
GetRawValue()183     constexpr int64_t GetRawValue() const
184     {
185         return value_;
186     }
187 
188 #else
GetAsInt()189     constexpr int64_t GetAsInt() const
190     {
191         ASSERT(std::holds_alternative<int64_t>(value_));
192         return std::get<int64_t>(value_);
193     }
194 
GetAsFloat()195     float GetAsFloat() const
196     {
197         ASSERT(std::holds_alternative<float>(value_));
198         return std::get<float>(value_);
199     }
200 
GetAsDouble()201     double GetAsDouble() const
202     {
203         ASSERT(std::holds_alternative<double>(value_));
204         return std::get<double>(value_);
205     }
206 
GetRawValue()207     constexpr int64_t GetRawValue() const
208     {
209         if (value_.index() == 0) {
210             UNREACHABLE();
211         } else if (value_.index() == 1) {
212             return std::get<int64_t>(value_);
213         } else if (value_.index() == 2U) {
214             return static_cast<int64_t>(bit_cast<int32_t>(std::get<float>(value_)));
215         } else if (value_.index() == 3U) {
216             return bit_cast<int64_t>(std::get<double>(value_));
217         }
218         UNREACHABLE();
219     }
220 
221     enum VariantID {
222         V_INVALID = 0,  // Pointer used for invalidate variants
223         V_INT64 = 1,
224         V_FLOAT32 = 2,
225         V_FLOAT64 = 3
226     };
227 
228     template <class T>
CheckVariantID()229     constexpr bool CheckVariantID() const
230     {
231 #ifndef __clang_analyzer__
232         // Immediate could be only signed (int/float)
233         // look at value_-type.
234         static_assert(std::is_signed_v<T>);
235         if constexpr (std::is_same<T, int64_t>()) {
236             return value_.index() == V_INT64;
237         }
238         if constexpr (std::is_same<T, float>()) {
239             return value_.index() == V_FLOAT32;
240         }
241         if constexpr (std::is_same<T, double>()) {
242             return value_.index() == V_FLOAT64;
243         }
244         return false;
245 #else
246         return true;
247 #endif  // !__clang_analyzer__
248     }
249 
IsValid()250     constexpr bool IsValid() const
251     {
252         return !std::holds_alternative<void *>(value_);
253     }
254 
GetType()255     TypeInfo GetType() const
256     {
257         switch (value_.index()) {
258             case V_INT64:
259                 return INT64_TYPE;
260             case V_FLOAT32:
261                 return FLOAT32_TYPE;
262             case V_FLOAT64:
263                 return FLOAT64_TYPE;
264             default:
265                 UNREACHABLE();
266                 return INVALID_TYPE;
267         }
268     }
269 
GetSize()270     constexpr size_t GetSize() const
271     {
272         switch (value_.index()) {
273             case V_INT64:
274                 return INT64_SIZE;
275             case V_FLOAT32:
276                 return FLOAT32_SIZE;
277             case V_FLOAT64:
278                 return FLOAT64_SIZE;
279             default:
280                 UNREACHABLE();
281                 return UNDEFINED_SIZE;
282         }
283     }
284 #endif  // NDEBUG
285 
286     bool operator==(Imm other) const
287     {
288         return value_ == other.value_;
289     }
290 
291     bool operator!=(Imm other) const
292     {
293         return !(operator==(other));
294     }
295 
296 private:
297 #ifndef NDEBUG
298     std::variant<void *, int64_t, float, double> value_ {nullptr};
299 #else
300     int64_t value_ {0};
301 #endif  // NDEBUG
302 };      // Imm
303 
304 class TypedImm final {
305 public:
306     template <typename T>
TypedImm(T imm)307     constexpr explicit TypedImm(T imm) : type_(imm), imm_(imm)
308     {
309     }
310 
GetType()311     TypeInfo GetType() const
312     {
313         return type_;
314     }
315 
GetImm()316     Imm GetImm() const
317     {
318         return imm_;
319     }
320 
321 private:
322     TypeInfo type_ {INVALID_TYPE};
323     Imm imm_ {0};
324 };
325 
326 // Why memory ref - because you may create one link for one encode-session
327 // And when you see this one - you can easy understand, what type of memory
328 //   you use. But if you load/store dirrectly address - you need to decode it
329 //   each time, when you read code
330 // model -> base + index<<scale + disp
331 class MemRef final {
332 public:
333     MemRef() = default;
334 
MemRef(Reg base)335     explicit MemRef(Reg base) : MemRef(base, 0) {}
MemRef(Reg base,ssize_t disp)336     MemRef(Reg base, ssize_t disp) : MemRef(base, INVALID_REGISTER, 0, disp) {}
MemRef(Reg base,Reg index,uint16_t scale)337     MemRef(Reg base, Reg index, uint16_t scale) : MemRef(base, index, scale, 0) {}
MemRef(Reg base,Reg index,uint16_t scale,ssize_t disp)338     MemRef(Reg base, Reg index, uint16_t scale, ssize_t disp) : disp_(disp), scale_(scale), base_(base), index_(index)
339     {
340         CHECK_LE(disp, std::numeric_limits<decltype(disp_)>::max());
341         CHECK_LE(scale, std::numeric_limits<decltype(scale_)>::max());
342     }
343     DEFAULT_MOVE_SEMANTIC(MemRef);
344     DEFAULT_COPY_SEMANTIC(MemRef);
345     ~MemRef() = default;
346 
GetBase()347     Reg GetBase() const
348     {
349         return base_;
350     }
GetIndex()351     Reg GetIndex() const
352     {
353         return index_;
354     }
GetScale()355     auto GetScale() const
356     {
357         return scale_;
358     }
GetDisp()359     auto GetDisp() const
360     {
361         return disp_;
362     }
363 
HasBase()364     bool HasBase() const
365     {
366         return base_.IsValid();
367     }
HasIndex()368     bool HasIndex() const
369     {
370         return index_.IsValid();
371     }
HasScale()372     bool HasScale() const
373     {
374         return HasIndex() && scale_ != 0;
375     }
HasDisp()376     bool HasDisp() const
377     {
378         return disp_ != 0;
379     }
380     // Ref must contain at least one of field
IsValid()381     bool IsValid() const
382     {
383         return HasBase() || HasIndex() || HasScale() || HasDisp();
384     }
385 
386     // return true if mem doesn't has index and scalar
IsOffsetMem()387     bool IsOffsetMem() const
388     {
389         return !HasIndex() && !HasScale();
390     }
391 
392     bool operator==(MemRef other) const
393     {
394         return (base_ == other.base_) && (index_ == other.index_) && (scale_ == other.scale_) && (disp_ == other.disp_);
395     }
396     bool operator!=(MemRef other) const
397     {
398         return !(operator==(other));
399     }
400 
401 private:
402     ssize_t disp_ {0};
403     uint16_t scale_ {0};
404     Reg base_ {INVALID_REGISTER};
405     Reg index_ {INVALID_REGISTER};
406 };  // MemRef
407 
408 class Shift final {
409 public:
Shift(Reg base,ShiftType type,uint32_t scale)410     explicit Shift(Reg base, ShiftType type, uint32_t scale) : scale_(scale), base_(base), type_(type) {}
Shift(Reg base,uint32_t scale)411     explicit Shift(Reg base, uint32_t scale) : Shift(base, ShiftType::LSL, scale) {}
412 
413     DEFAULT_MOVE_SEMANTIC(Shift);
414     DEFAULT_COPY_SEMANTIC(Shift);
415     ~Shift() = default;
416 
GetBase()417     Reg GetBase() const
418     {
419         return base_;
420     }
421 
GetType()422     ShiftType GetType() const
423     {
424         return type_;
425     }
426 
GetScale()427     uint32_t GetScale() const
428     {
429         return scale_;
430     }
431 
432 private:
433     uint32_t scale_ {0};
434     Reg base_;
435     ShiftType type_ {INVALID_SHIFT};
436 };
437 
438 }  // namespace ark::compiler
#endif  // COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H
440