• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_OPTIMIZER_CODEGEN_CODEGEN_INL_H
17 #define COMPILER_OPTIMIZER_CODEGEN_CODEGEN_INL_H
18 
19 namespace ark::compiler {
20 
21 /// 'live_inputs' shows that inst's source registers should be added the the mask
22 template <bool LIVE_INPUTS>
GetLiveRegisters(Inst * inst)23 std::pair<RegMask, VRegMask> Codegen::GetLiveRegisters(Inst *inst)
24 {
25     RegMask liveRegs;
26     VRegMask liveFpRegs;
27     if (!g_options.IsCompilerSaveOnlyLiveRegisters() || inst == nullptr) {
28         liveRegs.set();
29         liveFpRegs.set();
30         return {liveRegs, liveFpRegs};
31     }
32     // Run LiveRegisters pass only if it is actually required
33     if (!GetGraph()->IsAnalysisValid<LiveRegisters>()) {
34         GetGraph()->RunPass<LiveRegisters>();
35     }
36 
37     // Add registers from intervals that are live at inst's definition
38     auto &lr = GetGraph()->GetAnalysis<LiveRegisters>();
39     lr.VisitIntervalsWithLiveRegisters<LIVE_INPUTS>(inst, [&liveRegs, &liveFpRegs, this](const auto &li) {
40         auto reg = ConvertRegister(li->GetReg(), li->GetType());
41         GetEncoder()->SetRegister(&liveRegs, &liveFpRegs, reg);
42     });
43 
44     // Add live temp registers
45     liveRegs |= GetEncoder()->GetLiveTmpRegMask();
46     liveFpRegs |= GetEncoder()->GetLiveTmpFpRegMask();
47 
48     return {liveRegs, liveFpRegs};
49 }
50 
51 template <typename T, typename... Args>
CreateSlowPath(Inst * inst,Args &&...args)52 T *Codegen::CreateSlowPath(Inst *inst, Args &&...args)
53 {
54     static_assert(std::is_base_of_v<SlowPathBase, T>);
55     auto label = GetEncoder()->CreateLabel();
56     auto slowPath = GetLocalAllocator()->New<T>(label, inst, std::forward<Args>(args)...);
57     slowPaths_.push_back(slowPath);
58     return slowPath;
59 }
60 
61 /**
62  * Insert tracing code to the generated code. See `Trace` method in the `runtime/entrypoints.cpp`.
63  * NOTE(compiler): we should rework parameters assigning algorithm, that is duplicated here.
64  * @param params parameters to be passed to the TRACE entrypoint, first parameter must be TraceId value.
65  */
66 template <typename... Args>
InsertTrace(Args &&...params)67 void Codegen::InsertTrace(Args &&...params)
68 {
69     SCOPED_DISASM_STR(this, "Trace");
70     [[maybe_unused]] constexpr size_t MAX_PARAM_NUM = 8;
71     static_assert(sizeof...(Args) <= MAX_PARAM_NUM);
72     auto regfile = GetRegfile();
73     auto saveRegs = regfile->GetCallerSavedRegMask();
74     saveRegs.set(GetTarget().GetReturnRegId());
75     auto saveVregs = regfile->GetCallerSavedVRegMask();
76     saveVregs.set(GetTarget().GetReturnFpRegId());
77 
78     SaveCallerRegisters(saveRegs, saveVregs, false);
79     FillCallParams(std::forward<Args>(params)...);
80     EmitCallRuntimeCode(nullptr, EntrypointId::TRACE);
81     LoadCallerRegisters(saveRegs, saveVregs, false);
82 }
83 
84 template <bool IS_FASTPATH, typename... Args>
CallEntrypoint(Inst * inst,EntrypointId id,Reg dstReg,RegMask preservedRegs,Args &&...params)85 void Codegen::CallEntrypoint(Inst *inst, EntrypointId id, Reg dstReg, RegMask preservedRegs, Args &&...params)
86 {
87     ASSERT(inst != nullptr);
88     CHECK_EQ(sizeof...(Args), GetRuntime()->GetEntrypointArgsNum(id));
89     if (GetArch() == Arch::AARCH32) {
90         // There is a problem with 64-bit parameters:
91         // params number passed from entrypoints_gen.S.erb will be inconsistent with Aarch32 ABI.
92         // Thus, runtime bridges will have wrong params number (\paramsnum macro argument).
93         ASSERT(EnsureParamsFitIn32Bit({params...}));
94         ASSERT(!dstReg.IsValid() || dstReg.GetSize() <= WORD_SIZE);
95     }
96 
97     SCOPED_DISASM_STR(this, std::string("CallEntrypoint: ") + GetRuntime()->GetEntrypointName(id));
98     RegMask liveRegs {preservedRegs | GetLiveRegisters(inst).first};
99     RegMask paramsMask;
100     if (inst->HasImplicitRuntimeCall() && !GetRuntime()->IsEntrypointNoreturn(id)) {
101         SaveRegistersForImplicitRuntime(inst, &paramsMask, &liveRegs);
102     }
103 
104     ASSERT(IS_FASTPATH == GetRuntime()->IsEntrypointFastPath(id));
105     bool retRegAlive {liveRegs.Test(GetTarget().GetReturnRegId())};
106     // parameter regs: their initial values must be stored by the caller
107     // Other caller regs stored in bridges
108     FillOnlyParameters(&liveRegs, sizeof...(Args), IS_FASTPATH);
109 
110     if (IS_FASTPATH && retRegAlive && dstReg.IsValid()) {
111         Reg retReg = GetTarget().GetReturnReg(dstReg.GetType());
112         if (dstReg.GetId() != retReg.GetId()) {
113             GetEncoder()->SetRegister(&liveRegs, nullptr, retReg, true);
114         }
115     }
116 
117     GetEncoder()->SetRegister(&liveRegs, nullptr, dstReg, false);
118     SaveCallerRegisters(liveRegs, VRegMask(), true);
119 
120     if (sizeof...(Args) != 0) {
121         FillCallParams(std::forward<Args>(params)...);
122     }
123 
124     // Call Code
125     if (!EmitCallRuntimeCode(inst, id)) {
126         return;
127     }
128     if (dstReg.IsValid()) {
129         ASSERT(dstReg.IsScalar());
130         Reg retReg = GetTarget().GetReturnReg(dstReg.GetType());
131         if (!IS_FASTPATH && retRegAlive && dstReg.GetId() != retReg.GetId() &&
132             (!GetTarget().FirstParamIsReturnReg(retReg.GetType()) || sizeof...(Args) == 0U)) {
133             GetEncoder()->SetRegister(&liveRegs, nullptr, retReg, true);
134         }
135 
136         // We must:
137         //  sign extended INT8 and INT16 to INT32
138         //  zero extended UINT8 and UINT16 to UINT32
139         if (dstReg.GetSize() < WORD_SIZE) {
140             bool isSigned = DataType::IsTypeSigned(inst->GetType());
141             GetEncoder()->EncodeCast(dstReg.As(INT32_TYPE), isSigned, retReg, isSigned);
142         } else {
143             GetEncoder()->EncodeMov(dstReg, retReg);
144         }
145     }
146     CallEntrypointFinalize(liveRegs, paramsMask, inst);
147 }
148 
149 // The function is used for calling runtime functions through special bridges.
150 // !NOTE Don't use the function for calling runtime without bridges(it save only parameters on stack)
151 template <typename... Args>
CallRuntime(Inst * inst,EntrypointId id,Reg dstReg,RegMask preservedRegs,Args &&...params)152 void Codegen::CallRuntime(Inst *inst, EntrypointId id, Reg dstReg, RegMask preservedRegs, Args &&...params)
153 {
154     CallEntrypoint<false>(inst, id, dstReg, preservedRegs, std::forward<Args>(params)...);
155 }
156 
157 template <typename... Args>
CallFastPath(Inst * inst,EntrypointId id,Reg dstReg,RegMask preservedRegs,Args &&...params)158 void Codegen::CallFastPath(Inst *inst, EntrypointId id, Reg dstReg, RegMask preservedRegs, Args &&...params)
159 {
160     CallEntrypoint<true>(inst, id, dstReg, preservedRegs, std::forward<Args>(params)...);
161 }
162 
163 template <typename... Args>
CallRuntimeWithMethod(Inst * inst,void * method,EntrypointId eid,Reg dstReg,Args &&...params)164 void Codegen::CallRuntimeWithMethod(Inst *inst, void *method, EntrypointId eid, Reg dstReg, Args &&...params)
165 {
166     if (GetGraph()->IsAotMode()) {
167         ScopedTmpReg methodReg(GetEncoder());
168         LoadMethod(methodReg);
169         CallRuntime(inst, eid, dstReg, RegMask::GetZeroMask(), methodReg, std::forward<Args>(params)...);
170     } else {
171         if (Is64BitsArch(GetArch())) {
172             CallRuntime(inst, eid, dstReg, RegMask::GetZeroMask(), TypedImm(reinterpret_cast<uint64_t>(method)),
173                         std::forward<Args>(params)...);
174         } else {
175             // uintptr_t causes problems on host cross-jit compilation
176             CallRuntime(inst, eid, dstReg, RegMask::GetZeroMask(), TypedImm(down_cast<uint32_t>(method)),
177                         std::forward<Args>(params)...);
178         }
179     }
180 }
181 
182 template <typename... Args>
CallBarrier(RegMask liveRegs,VRegMask liveVregs,std::variant<EntrypointId,Reg> entrypoint,Args &&...params)183 void Codegen::CallBarrier(RegMask liveRegs, VRegMask liveVregs, std::variant<EntrypointId, Reg> entrypoint,
184                           Args &&...params)
185 {
186     SaveCallerRegisters(liveRegs, liveVregs, true);
187     FillCallParams(std::forward<Args>(params)...);
188     EmitCallRuntimeCode(nullptr, entrypoint);
189     LoadCallerRegisters(liveRegs, liveVregs, true);
190 }
191 
192 template <typename T>
CreateUnaryCheck(Inst * inst,RuntimeInterface::EntrypointId id,DeoptimizeType type,Condition cc)193 void Codegen::CreateUnaryCheck(Inst *inst, RuntimeInterface::EntrypointId id, DeoptimizeType type, Condition cc)
194 {
195     [[maybe_unused]] auto ss = inst->GetSaveState();
196     ASSERT(ss != nullptr && (ss->GetOpcode() == Opcode::SaveState || ss->GetOpcode() == Opcode::SaveStateDeoptimize));
197 
198     LabelHolder::LabelId slowPath;
199     if (inst->CanDeoptimize()) {
200         slowPath = CreateSlowPath<SlowPathDeoptimize>(inst, type)->GetLabel();
201     } else {
202         slowPath = CreateSlowPath<T>(inst, id)->GetLabel();
203     }
204     auto srcType = inst->GetInputType(0);
205     auto src = ConvertRegister(inst->GetSrcReg(0), srcType);
206     GetEncoder()->EncodeJump(slowPath, src, cc);
207 }
208 
209 // The function alignment up the value from alignment_reg using tmp_reg.
210 
GetStackOffset(Location location)211 inline ssize_t Codegen::GetStackOffset(Location location)
212 {
213     if (location.GetKind() == LocationType::STACK_ARGUMENT) {
214         return location.GetValue() * GetFrameLayout().GetSlotSize();
215     }
216 
217     if (location.GetKind() == LocationType::STACK_PARAMETER) {
218         return GetFrameLayout().GetFrameSize<CFrameLayout::OffsetUnit::BYTES>() +
219                (location.GetValue() * GetFrameLayout().GetSlotSize());
220     }
221 
222     ASSERT(location.GetKind() == LocationType::STACK);
223     return GetFrameLayout().GetSpillOffsetFromSpInBytes(location.GetValue());
224 }
225 
GetMemRefForSlot(Location location)226 inline MemRef Codegen::GetMemRefForSlot(Location location)
227 {
228     ASSERT(location.IsAnyStack());
229     return MemRef(SpReg(), GetStackOffset(location));
230 }
231 
SpReg()232 inline Reg Codegen::SpReg() const
233 {
234     return GetTarget().GetStackReg();
235 }
236 
FpReg()237 inline Reg Codegen::FpReg() const
238 {
239     return GetTarget().GetFrameReg();
240 }
241 
GetDisasm()242 inline const Disassembly *Codegen::GetDisasm() const
243 {
244     return &disasm_;
245 }
246 
GetDisasm()247 inline Disassembly *Codegen::GetDisasm()
248 {
249     return &disasm_;
250 }
251 
AddLiveOut(const BasicBlock * bb,const Register reg)252 inline void Codegen::AddLiveOut(const BasicBlock *bb, const Register reg)
253 {
254     liveOuts_[bb].Set(reg);
255 }
256 
GetLiveOut(const BasicBlock * bb)257 inline RegMask Codegen::GetLiveOut(const BasicBlock *bb) const
258 {
259     auto it = liveOuts_.find(bb);
260     return it != liveOuts_.end() ? it->second : RegMask();
261 }
262 
ThreadReg()263 inline Reg Codegen::ThreadReg() const
264 {
265     return Reg(GetThreadReg(GetArch()), GetTarget().GetPtrRegType());
266 }
267 
OffsetFitReferenceTypeSize(uint64_t offset)268 inline bool Codegen::OffsetFitReferenceTypeSize(uint64_t offset) const
269 {
270     // -1 because some arch uses signed offset
271     // NOLINTNEXTLINE(hicpp-signed-bitwise)
272     uint64_t maxOffset = 1ULL << (DataType::GetTypeSize(DataType::REFERENCE, GetArch()) - 1);
273     return offset < maxOffset;
274 }
275 
GetUsedRegs()276 inline RegMask Codegen::GetUsedRegs() const
277 {
278     return usedRegs_;
279 }
GetUsedVRegs()280 inline RegMask Codegen::GetUsedVRegs() const
281 {
282     return usedVregs_;
283 }
284 
GetVtableShift()285 inline uint32_t Codegen::GetVtableShift()
286 {
287     // The size of the VTable element is equal to the size of pointers for the architecture
288     // (not the size of pointer to objects)
289     constexpr uint32_t SHIFT_64_BITS = 3;
290     constexpr uint32_t SHIFT_32_BITS = 2;
291     return Is64BitsArch(GetGraph()->GetArch()) ? SHIFT_64_BITS : SHIFT_32_BITS;
292 }
293 
294 template <typename Arg, typename... Args>
AddParamRegsInLiveMasksHandleArgs(ParameterInfo * paramInfo,RegMask * liveRegs,VRegMask * liveVregs,Arg param,Args &&...params)295 ALWAYS_INLINE void Codegen::AddParamRegsInLiveMasksHandleArgs(ParameterInfo *paramInfo, RegMask *liveRegs,
296                                                               VRegMask *liveVregs, Arg param, Args &&...params)
297 {
298     auto currDst = paramInfo->GetNativeParam(param.GetType());
299     if (std::holds_alternative<Reg>(currDst)) {
300         auto reg = std::get<Reg>(currDst);
301         if (reg.IsScalar()) {
302             liveRegs->set(reg.GetId());
303         } else {
304             liveVregs->set(reg.GetId());
305         }
306     } else {
307         GetEncoder()->SetFalseResult();
308         UNREACHABLE();
309     }
310     if constexpr (sizeof...(Args) != 0) {
311         AddParamRegsInLiveMasksHandleArgs(paramInfo, liveRegs, liveVregs, std::forward<Args>(params)...);
312     }
313 }
314 
315 template <typename... Args>
AddParamRegsInLiveMasks(RegMask * liveRegs,VRegMask * liveVregs,Args &&...params)316 void Codegen::AddParamRegsInLiveMasks(RegMask *liveRegs, VRegMask *liveVregs, Args &&...params)
317 {
318     auto callconv = GetCallingConvention();
319     auto paramInfo = callconv->GetParameterInfo(0);
320     AddParamRegsInLiveMasksHandleArgs(paramInfo, liveRegs, liveVregs, std::forward<Args>(params)...);
321 }
322 
323 template <typename... Args>
CreateStubCall(Inst * inst,RuntimeInterface::IntrinsicId intrinsicId,Reg dst,Args &&...params)324 void Codegen::CreateStubCall(Inst *inst, RuntimeInterface::IntrinsicId intrinsicId, Reg dst, Args &&...params)
325 {
326     VRegMask liveVregs;
327     RegMask liveRegs;
328     AddParamRegsInLiveMasks(&liveRegs, &liveVregs, params...);
329     auto enc = GetEncoder();
330     {
331         SCOPED_DISASM_STR(this, "Save caller saved regs");
332         SaveCallerRegisters(liveRegs, liveVregs, true);
333     }
334 
335     FillCallParams(std::forward<Args>(params)...);
336     CallIntrinsic(inst, intrinsicId);
337 
338     if (inst->GetSaveState() != nullptr) {
339         CreateStackMap(inst);
340     }
341 
342     if (dst.IsValid()) {
343         Reg retVal = GetTarget().GetReturnReg(dst.GetType());
344         if (dst.GetId() != retVal.GetId()) {
345             enc->SetRegister(&liveRegs, &liveVregs, retVal, true);
346         }
347         ASSERT(dst.IsScalar());
348         enc->EncodeMov(dst, retVal);
349     }
350 
351     {
352         SCOPED_DISASM_STR(this, "Restore caller saved regs");
353         enc->SetRegister(&liveRegs, &liveVregs, dst, false);
354         LoadCallerRegisters(liveRegs, liveVregs, true);
355     }
356 }
357 
358 template <typename T>
EncodeImms(const T & imms,bool skipFirstLocation)359 void Codegen::EncodeImms(const T &imms, bool skipFirstLocation)
360 {
361     auto paramInfo = GetCallingConvention()->GetParameterInfo(0);
362     auto immType = DataType::INT32;
363     if (skipFirstLocation) {
364         paramInfo->GetNextLocation(immType);
365     }
366     for (auto imm : imms) {
367         auto location = paramInfo->GetNextLocation(immType);
368         ASSERT(location.IsFixedRegister());
369         auto dstReg = ConvertRegister(location.GetValue(), immType);
370         GetEncoder()->EncodeMov(dstReg, Imm(imm));
371     }
372 }
373 
374 template <typename... Args>
375 void FillPostWrbCallParams(MemRef mem, Args &&...params);
376 
377 template <size_t IMM_ARRAY_SIZE>
378 class Codegen::FillCallParamsHelper {
379 public:
380     using ImmsIter = typename std::array<std::pair<Reg, Imm>, IMM_ARRAY_SIZE>::iterator;
381 
FillCallParamsHelper(Codegen * cg,ParameterInfo * paramInfo,SpillFillInst * regMoves,ArenaVector<Reg> * spMoves,ImmsIter immsIter)382     FillCallParamsHelper(Codegen *cg, ParameterInfo *paramInfo, SpillFillInst *regMoves, ArenaVector<Reg> *spMoves,
383                          ImmsIter immsIter)
384         : cg_(cg), paramInfo_(paramInfo), regMoves_(regMoves), spMoves_(spMoves), immsIter_(immsIter)
385     {
386     }
387 
388     template <typename Arg, typename... Args>
FillCallParamsHandleOperands(Arg && arg,Args &&...params)389     ALWAYS_INLINE void FillCallParamsHandleOperands(Arg &&arg, Args &&...params)
390     {
391         Location dst;
392         auto type = arg.GetType().ToDataType();
393         dst = paramInfo_->GetNextLocation(type);
394         if (dst.IsStackArgument()) {
395             cg_->GetEncoder()->SetFalseResult();
396             UNREACHABLE();  // Move to BoundaryFrame
397         }
398 
399         static_assert(std::is_same_v<std::decay_t<Arg>, TypedImm> || std::is_convertible_v<Arg, Reg>);
400         if constexpr (std::is_same_v<std::decay_t<Arg>, TypedImm>) {
401             auto reg = cg_->ConvertRegister(dst.GetValue(), type);
402             *immsIter_ = {reg, arg.GetImm()};
403             immsIter_++;
404         } else {
405             Reg reg(std::forward<Arg>(arg));
406             if (reg == cg_->SpReg()) {
407                 // SP should be handled separately, since on the ARM64 target it has ID out of range
408                 spMoves_->emplace_back(cg_->ConvertRegister(dst.GetValue(), type));
409             } else {
410                 regMoves_->AddSpillFill(Location::MakeRegister(reg.GetId(), type), dst, type);
411             }
412         }
413         if constexpr (sizeof...(Args) != 0) {
414             FillCallParamsHandleOperands(std::forward<Args>(params)...);
415         }
416     }
417 
418 private:
419     Codegen *cg_ {};
420     ParameterInfo *paramInfo_ {};
421     SpillFillInst *regMoves_ {};
422     ArenaVector<Reg> *spMoves_ {};
423     ImmsIter immsIter_ {};
424 };
425 
426 template <typename T, typename... Args>
CountParameters()427 constexpr std::pair<size_t, size_t> CountParameters()
428 {
429     static_assert(std::is_same_v<std::decay_t<T>, TypedImm> != std::is_convertible_v<T, Reg>);
430     if constexpr (sizeof...(Args) != 0) {
431         constexpr auto IMM_REG_COUNT = CountParameters<Args...>();
432 
433         if constexpr (std::is_same_v<std::decay_t<T>, TypedImm>) {
434             return {IMM_REG_COUNT.first + 1, IMM_REG_COUNT.second};
435         } else if constexpr (std::is_convertible_v<T, Reg>) {
436             return {IMM_REG_COUNT.first, IMM_REG_COUNT.second + 1};
437         }
438     }
439     return {std::is_same_v<std::decay_t<T>, TypedImm>, std::is_convertible_v<T, Reg>};
440 }
441 
442 template <typename... Args>
FillCallParams(Args &&...params)443 void Codegen::FillCallParams(Args &&...params)
444 {
445     SCOPED_DISASM_STR(this, "FillCallParams");
446     if constexpr (sizeof...(Args) != 0) {
447         constexpr size_t IMMEDIATES_COUNT = CountParameters<Args...>().first;
448         constexpr size_t REGS_COUNT = CountParameters<Args...>().second;
449         // Native call - do not add reserve parameters
450         auto paramInfo = GetCallingConvention()->GetParameterInfo(0);
451         std::array<std::pair<Reg, Imm>, IMMEDIATES_COUNT> immediates {};
452         ArenaVector<Reg> spMoves(GetLocalAllocator()->Adapter());
453         auto regMoves = GetGraph()->CreateInstSpillFill();
454         spMoves.reserve(REGS_COUNT);
455         regMoves->GetSpillFills().reserve(REGS_COUNT);
456 
457         FillCallParamsHelper<IMMEDIATES_COUNT> h {this, paramInfo, regMoves, &spMoves, immediates.begin()};
458         h.FillCallParamsHandleOperands(std::forward<Args>(params)...);
459 
460         // Resolve registers move order and encode
461         spillFillsResolver_.ResolveIfRequired(regMoves);
462         SpillFillEncoder(this, regMoves).EncodeSpillFill();
463 
464         // Encode immediates moves
465         for (auto &immValues : immediates) {
466             GetEncoder()->EncodeMov(immValues.first, immValues.second);
467         }
468 
469         // Encode moves from SP reg
470         for (auto dst : spMoves) {
471             GetEncoder()->EncodeMov(dst, SpReg());
472         }
473     }
474 }
475 
476 template <typename... Args>
FillPostWrbCallParams(MemRef mem,Args &&...params)477 void Codegen::FillPostWrbCallParams(MemRef mem, Args &&...params)
478 {
479     auto base {mem.GetBase().As(TypeInfo::FromDataType(DataType::REFERENCE, GetArch()))};
480     if (mem.HasIndex()) {
481         ASSERT(mem.GetScale() == 0 && !mem.HasDisp());
482         FillCallParams(base, mem.GetIndex(), std::forward<Args>(params)...);
483     } else {
484         FillCallParams(base, TypedImm(mem.GetDisp()), std::forward<Args>(params)...);
485     }
486 }
487 
488 }  // namespace ark::compiler
489 
490 #endif  // COMPILER_OPTIMIZER_CODEGEN_CODEGEN_INL_H
491