1 /* 2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H 17 #define COMPILER_OPTIMIZER_CODEGEN_OPERANDS_H 18 19 /* 20 Arch-feature definitions 21 */ 22 #include <bitset> 23 #include <cstdint> 24 #include <type_traits> 25 26 #include "type_info.h" 27 #include "utils/arch.h" 28 #include "utils/arena_containers.h" 29 #include "utils/bit_field.h" 30 #include "utils/bit_utils.h" 31 #include "utils/regmask.h" 32 #include "compiler/optimizer/ir/constants.h" 33 #include "compiler/optimizer/ir/datatype.h" 34 #include "utils/type_helpers.h" 35 36 namespace ark::compiler { 37 // Mapping model for registers: 38 // reg-reg - support getters for small parts of registers 39 // reg-other - mapping between types of registers 40 enum RegMapping : uint32_t { 41 SCALAR_SCALAR = 1UL << 0UL, 42 SCALAR_VECTOR = 1UL << 1UL, 43 SCALAR_FLOAT = 1UL << 2UL, 44 VECTOR_VECTOR = 1UL << 3UL, 45 VECTOR_FLOAT = 1UL << 4UL, 46 FLOAT_FLOAT = 1UL << 5UL 47 }; 48 49 constexpr uint8_t INVALID_REG_ID = std::numeric_limits<uint8_t>::max(); 50 constexpr uint8_t ACC_REG_ID = INVALID_REG_ID - 1U; 51 52 class Reg final { 53 public: 54 using RegIDType = uint8_t; 55 using RegSizeType = size_t; 56 57 constexpr Reg() = default; 58 DEFAULT_MOVE_SEMANTIC(Reg); 59 DEFAULT_COPY_SEMANTIC(Reg); 60 ~Reg() = default; 61 62 // Default register constructor Reg(RegIDType id,TypeInfo type)63 constexpr Reg(RegIDType id, TypeInfo type) : id_(id), type_(type) {} 64 GetId()65 constexpr RegIDType GetId() const 66 { 67 return id_; 68 } 69 GetMask()70 constexpr size_t GetMask() const 71 { 72 CHECK_LT(id_, 32U); 73 return (1U << id_); 74 } 75 GetType()76 constexpr TypeInfo GetType() const 77 { 78 return type_; 79 } 80 GetSize()81 RegSizeType GetSize() const 82 { 83 return GetType().GetSize(); 84 } 85 IsScalar()86 bool IsScalar() const 87 { 88 return GetType().IsScalar(); 89 } 90 IsFloat()91 bool IsFloat() const 92 { 93 return GetType().IsFloat(); 94 } 95 IsValid()96 constexpr bool IsValid() const 97 { 98 return type_ != INVALID_TYPE && id_ != INVALID_REG_ID; 99 } 100 As(TypeInfo type)101 Reg As(TypeInfo type) const 102 { 103 return Reg(GetId(), type); 104 } 105 106 constexpr bool operator==(Reg other) const 107 { 108 return (GetId() == other.GetId()) && (GetType() == other.GetType()); 109 } 110 111 constexpr bool operator!=(Reg other) const 112 { 113 return !operator==(other); 114 } 115 Dump()116 void Dump() 117 { 118 std::cerr << " Reg: id = " << static_cast<int64_t>(id_) << ", "; 119 type_.Dump(); 120 std::cerr << "\n"; 121 } 122 123 private: 124 RegIDType id_ {INVALID_REG_ID}; 125 TypeInfo type_ {INVALID_TYPE}; 126 }; // Reg 127 128 constexpr Reg INVALID_REGISTER = Reg(); 129 130 static_assert(!INVALID_REGISTER.IsValid()); 131 static_assert(sizeof(Reg) <= sizeof(uintptr_t)); 132 133 /** 134 * Immediate class may hold only int or float values (maybe vectors in future). 135 * It knows nothing about pointers and bools (bools maybe be in future). 136 */ 137 class Imm final { 138 static constexpr size_t UNDEFINED_SIZE = 0; 139 static constexpr size_t INT64_SIZE = 64; 140 static constexpr size_t FLOAT32_SIZE = 32; 141 static constexpr size_t FLOAT64_SIZE = 64; 142 143 public: 144 constexpr Imm() = default; 145 146 template <typename T> Imm(T value)147 constexpr explicit Imm(T value) : value_(static_cast<int64_t>(value)) 148 { 149 using Type = std::decay_t<T>; 150 static_assert(std::is_integral_v<Type> || std::is_enum_v<Type>); 151 } 152 153 // Partial template specialization Imm(int64_t value)154 constexpr explicit Imm(int64_t value) : value_(value) {}; 155 #ifndef NDEBUG Imm(double value)156 constexpr explicit Imm(double value) : value_(value) {}; Imm(float value)157 constexpr explicit Imm(float value) : value_(value) {}; 158 #else Imm(double value)159 explicit Imm(double value) : value_(bit_cast<uint64_t>(value)) {}; Imm(float value)160 explicit Imm(float value) : value_(bit_cast<uint32_t>(value)) {}; 161 #endif // !NDEBUG 162 163 DEFAULT_MOVE_SEMANTIC(Imm); 164 DEFAULT_COPY_SEMANTIC(Imm); 165 ~Imm() = default; 166 167 #ifdef NDEBUG GetAsInt()168 constexpr int64_t GetAsInt() const 169 { 170 return value_; 171 } 172 GetAsFloat()173 float GetAsFloat() const 174 { 175 return bit_cast<float>(static_cast<int32_t>(value_)); 176 } 177 GetAsDouble()178 double GetAsDouble() const 179 { 180 return bit_cast<double>(value_); 181 } 182 GetRawValue()183 constexpr int64_t GetRawValue() const 184 { 185 return value_; 186 } 187 188 #else GetAsInt()189 constexpr int64_t GetAsInt() const 190 { 191 ASSERT(std::holds_alternative<int64_t>(value_)); 192 return std::get<int64_t>(value_); 193 } 194 GetAsFloat()195 float GetAsFloat() const 196 { 197 ASSERT(std::holds_alternative<float>(value_)); 198 return std::get<float>(value_); 199 } 200 GetAsDouble()201 double GetAsDouble() const 202 { 203 ASSERT(std::holds_alternative<double>(value_)); 204 return std::get<double>(value_); 205 } 206 GetRawValue()207 constexpr int64_t GetRawValue() const 208 { 209 if (value_.index() == 0) { 210 UNREACHABLE(); 211 } else if (value_.index() == 1) { 212 return std::get<int64_t>(value_); 213 } else if (value_.index() == 2U) { 214 return static_cast<int64_t>(bit_cast<int32_t>(std::get<float>(value_))); 215 } else if (value_.index() == 3U) { 216 return bit_cast<int64_t>(std::get<double>(value_)); 217 } 218 UNREACHABLE(); 219 } 220 221 enum VariantID { 222 V_INVALID = 0, // Pointer used for invalidate variants 223 V_INT64 = 1, 224 V_FLOAT32 = 2, 225 V_FLOAT64 = 3 226 }; 227 228 template <class T> CheckVariantID()229 constexpr bool CheckVariantID() const 230 { 231 #ifndef __clang_analyzer__ 232 // Immediate could be only signed (int/float) 233 // look at value_-type. 234 static_assert(std::is_signed_v<T>); 235 if constexpr (std::is_same<T, int64_t>()) { 236 return value_.index() == V_INT64; 237 } 238 if constexpr (std::is_same<T, float>()) { 239 return value_.index() == V_FLOAT32; 240 } 241 if constexpr (std::is_same<T, double>()) { 242 return value_.index() == V_FLOAT64; 243 } 244 return false; 245 #else 246 return true; 247 #endif // !__clang_analyzer__ 248 } 249 IsValid()250 constexpr bool IsValid() const 251 { 252 return !std::holds_alternative<void *>(value_); 253 } 254 GetType()255 TypeInfo GetType() const 256 { 257 switch (value_.index()) { 258 case V_INT64: 259 return INT64_TYPE; 260 case V_FLOAT32: 261 return FLOAT32_TYPE; 262 case V_FLOAT64: 263 return FLOAT64_TYPE; 264 default: 265 UNREACHABLE(); 266 return INVALID_TYPE; 267 } 268 } 269 GetSize()270 constexpr size_t GetSize() const 271 { 272 switch (value_.index()) { 273 case V_INT64: 274 return INT64_SIZE; 275 case V_FLOAT32: 276 return FLOAT32_SIZE; 277 case V_FLOAT64: 278 return FLOAT64_SIZE; 279 default: 280 UNREACHABLE(); 281 return UNDEFINED_SIZE; 282 } 283 } 284 #endif // NDEBUG 285 286 bool operator==(Imm other) const 287 { 288 return value_ == other.value_; 289 } 290 291 bool operator!=(Imm other) const 292 { 293 return !(operator==(other)); 294 } 295 296 private: 297 #ifndef NDEBUG 298 std::variant<void *, int64_t, float, double> value_ {nullptr}; 299 #else 300 int64_t value_ {0}; 301 #endif // NDEBUG 302 }; // Imm 303 304 class TypedImm final { 305 public: 306 template <typename T> TypedImm(T imm)307 constexpr explicit TypedImm(T imm) : type_(imm), imm_(imm) 308 { 309 } 310 GetType()311 TypeInfo GetType() const 312 { 313 return type_; 314 } 315 GetImm()316 Imm GetImm() const 317 { 318 return imm_; 319 } 320 321 private: 322 TypeInfo type_ {INVALID_TYPE}; 323 Imm imm_ {0}; 324 }; 325 326 // Why memory ref - because you may create one link for one encode-session 327 // And when you see this one - you can easy understand, what type of memory 328 // you use. But if you load/store dirrectly address - you need to decode it 329 // each time, when you read code 330 // model -> base + index<<scale + disp 331 class MemRef final { 332 public: 333 MemRef() = default; 334 MemRef(Reg base)335 explicit MemRef(Reg base) : MemRef(base, 0) {} MemRef(Reg base,ssize_t disp)336 MemRef(Reg base, ssize_t disp) : MemRef(base, INVALID_REGISTER, 0, disp) {} MemRef(Reg base,Reg index,uint16_t scale)337 MemRef(Reg base, Reg index, uint16_t scale) : MemRef(base, index, scale, 0) {} MemRef(Reg base,Reg index,uint16_t scale,ssize_t disp)338 MemRef(Reg base, Reg index, uint16_t scale, ssize_t disp) : disp_(disp), scale_(scale), base_(base), index_(index) 339 { 340 CHECK_LE(disp, std::numeric_limits<decltype(disp_)>::max()); 341 CHECK_LE(scale, std::numeric_limits<decltype(scale_)>::max()); 342 } 343 DEFAULT_MOVE_SEMANTIC(MemRef); 344 DEFAULT_COPY_SEMANTIC(MemRef); 345 ~MemRef() = default; 346 GetBase()347 Reg GetBase() const 348 { 349 return base_; 350 } GetIndex()351 Reg GetIndex() const 352 { 353 return index_; 354 } GetScale()355 auto GetScale() const 356 { 357 return scale_; 358 } GetDisp()359 auto GetDisp() const 360 { 361 return disp_; 362 } 363 HasBase()364 bool HasBase() const 365 { 366 return base_.IsValid(); 367 } HasIndex()368 bool HasIndex() const 369 { 370 return index_.IsValid(); 371 } HasScale()372 bool HasScale() const 373 { 374 return HasIndex() && scale_ != 0; 375 } HasDisp()376 bool HasDisp() const 377 { 378 return disp_ != 0; 379 } 380 // Ref must contain at least one of field IsValid()381 bool IsValid() const 382 { 383 return HasBase() || HasIndex() || HasScale() || HasDisp(); 384 } 385 386 // return true if mem doesn't has index and scalar IsOffsetMem()387 bool IsOffsetMem() const 388 { 389 return !HasIndex() && !HasScale(); 390 } 391 392 bool operator==(MemRef other) const 393 { 394 return (base_ == other.base_) && (index_ == other.index_) && (scale_ == other.scale_) && (disp_ == other.disp_); 395 } 396 bool operator!=(MemRef other) const 397 { 398 return !(operator==(other)); 399 } 400 401 private: 402 ssize_t disp_ {0}; 403 uint16_t scale_ {0}; 404 Reg base_ {INVALID_REGISTER}; 405 Reg index_ {INVALID_REGISTER}; 406 }; // MemRef 407 408 class Shift final { 409 public: Shift(Reg base,ShiftType type,uint32_t scale)410 explicit Shift(Reg base, ShiftType type, uint32_t scale) : scale_(scale), base_(base), type_(type) {} Shift(Reg base,uint32_t scale)411 explicit Shift(Reg base, uint32_t scale) : Shift(base, ShiftType::LSL, scale) {} 412 413 DEFAULT_MOVE_SEMANTIC(Shift); 414 DEFAULT_COPY_SEMANTIC(Shift); 415 ~Shift() = default; 416 GetBase()417 Reg GetBase() const 418 { 419 return base_; 420 } 421 GetType()422 ShiftType GetType() const 423 { 424 return type_; 425 } 426 GetScale()427 uint32_t GetScale() const 428 { 429 return scale_; 430 } 431 432 private: 433 uint32_t scale_ {0}; 434 Reg base_; 435 ShiftType type_ {INVALID_SHIFT}; 436 }; 437 438 } // namespace ark::compiler 439 #endif // COMPILER_OPTIMIZER_CODEGEN_REGISTERS_H 440