• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_OPTIMIZER_CODEGEN_SCOPED_TMP_REG_H
17 #define COMPILER_OPTIMIZER_CODEGEN_SCOPED_TMP_REG_H
18 
19 #include "encode.h"
20 
21 namespace ark::compiler {
22 
23 /**
24  * This class is using to acquire/release temp register using RAII technique.
25  *
26  * @tparam lazy if true, temp register will be acquired in the constructor, otherwise user should acquire it explicitly.
27  */
28 template <bool LAZY>
29 class ScopedTmpRegImpl {
30 public:
ScopedTmpRegImpl(Encoder * encoder)31     explicit ScopedTmpRegImpl(Encoder *encoder) : ScopedTmpRegImpl(encoder, false) {}
ScopedTmpRegImpl(Encoder * encoder,bool withLr)32     ScopedTmpRegImpl(Encoder *encoder, bool withLr) : encoder_(encoder)
33     {
34         if constexpr (!LAZY) {  // NOLINT
35             auto linkReg = encoder->GetTarget().GetLinkReg();
36             withLr &= encoder->IsLrAsTempRegEnabled();
37             if (withLr && encoder->IsScratchRegisterReleased(linkReg)) {
38                 reg_ = linkReg;
39                 encoder->AcquireScratchRegister(linkReg);
40             } else {
41                 reg_ = encoder->AcquireScratchRegister(Is64BitsArch(encoder->GetArch()) ? INT64_TYPE : INT32_TYPE);
42             }
43         }
44     }
45 
ScopedTmpRegImpl(Encoder * encoder,TypeInfo type)46     ScopedTmpRegImpl(Encoder *encoder, TypeInfo type) : encoder_(encoder), reg_(encoder->AcquireScratchRegister(type))
47     {
48         static_assert(!LAZY);
49     }
50 
ScopedTmpRegImpl(Encoder * encoder,Reg reg)51     ScopedTmpRegImpl(Encoder *encoder, Reg reg) : encoder_(encoder), reg_(reg)
52     {
53         static_assert(!LAZY);
54         encoder->AcquireScratchRegister(reg);
55     }
56 
ScopedTmpRegImpl(ScopedTmpRegImpl && other)57     ScopedTmpRegImpl(ScopedTmpRegImpl &&other) noexcept
58     {
59         encoder_ = other.encoder_;
60         reg_ = other.reg_;
61         other.reg_ = Reg();
62         ASSERT(!other.reg_.IsValid());
63     }
64 
~ScopedTmpRegImpl()65     virtual ~ScopedTmpRegImpl()
66     {
67         if (reg_.IsValid()) {
68             encoder_->ReleaseScratchRegister(reg_);
69         }
70     }
71 
72     NO_COPY_SEMANTIC(ScopedTmpRegImpl);
73     NO_MOVE_OPERATOR(ScopedTmpRegImpl);
74 
GetReg()75     Reg GetReg() const
76     {
77         return reg_;
78     }
79 
GetType()80     TypeInfo GetType() const
81     {
82         return reg_.GetType();
83     }
84 
85     // NOLINTNEXTLINE(*-explicit-constructor)
Reg()86     operator Reg() const
87     {
88         return reg_;
89     }
90 
ChangeType(TypeInfo tp)91     void ChangeType(TypeInfo tp)
92     {
93         ASSERT(tp.IsScalar() == reg_.IsScalar());
94         reg_ = Reg(reg_.GetId(), tp);
95     }
96 
Release()97     virtual void Release()
98     {
99         if (reg_.IsValid()) {
100             encoder_->ReleaseScratchRegister(reg_);
101             reg_ = INVALID_REGISTER;
102         }
103     }
104 
Acquire()105     void Acquire()
106     {
107         ASSERT(!reg_.IsValid());
108         reg_ = encoder_->AcquireScratchRegister(Is64BitsArch(encoder_->GetArch()) ? INT64_TYPE : INT32_TYPE);
109         ASSERT(reg_.IsValid());
110     }
111 
AcquireWithLr()112     void AcquireWithLr()
113     {
114         ASSERT(!reg_.IsValid());
115         auto linkReg = encoder_->GetTarget().GetLinkReg();
116         if (encoder_->IsLrAsTempRegEnabled() && encoder_->IsScratchRegisterReleased(linkReg)) {
117             reg_ = linkReg;
118             encoder_->AcquireScratchRegister(linkReg);
119         } else {
120             reg_ = encoder_->AcquireScratchRegister(Is64BitsArch(encoder_->GetArch()) ? INT64_TYPE : INT32_TYPE);
121         }
122         ASSERT(reg_.IsValid());
123     }
124 
AcquireIfInvalid()125     void AcquireIfInvalid()
126     {
127         if (!reg_.IsValid()) {
128             reg_ = encoder_->AcquireScratchRegister(Is64BitsArch(encoder_->GetArch()) ? INT64_TYPE : INT32_TYPE);
129             ASSERT(reg_.IsValid());
130         }
131     }
132 
133 protected:
GetEncoder()134     Encoder *GetEncoder()
135     {
136         return encoder_;
137     }
138 
139 private:
140     Encoder *encoder_ {nullptr};
141     Reg reg_;
142 };
143 
144 struct ScopedTmpReg : public ScopedTmpRegImpl<false> {
145     using ScopedTmpRegImpl<false>::ScopedTmpRegImpl;
146 };
147 
148 struct ScopedTmpRegLazy : public ScopedTmpRegImpl<true> {
149     using ScopedTmpRegImpl<true>::ScopedTmpRegImpl;
150 };
151 
152 struct ScopedTmpRegU16 : public ScopedTmpReg {
ScopedTmpRegU16ScopedTmpRegU16153     explicit ScopedTmpRegU16(Encoder *encoder) : ScopedTmpReg(encoder, INT16_TYPE) {}
154 };
155 
156 struct ScopedTmpRegU32 : public ScopedTmpReg {
ScopedTmpRegU32ScopedTmpRegU32157     explicit ScopedTmpRegU32(Encoder *encoder) : ScopedTmpReg(encoder, INT32_TYPE) {}
158 };
159 
160 struct ScopedTmpRegU64 : public ScopedTmpReg {
ScopedTmpRegU64ScopedTmpRegU64161     explicit ScopedTmpRegU64(Encoder *encoder) : ScopedTmpReg(encoder, INT64_TYPE) {}
162 };
163 
164 struct ScopedTmpRegF32 : public ScopedTmpReg {
ScopedTmpRegF32ScopedTmpRegF32165     explicit ScopedTmpRegF32(Encoder *encoder) : ScopedTmpReg(encoder, FLOAT32_TYPE) {}
166 };
167 
168 struct ScopedTmpRegF64 : public ScopedTmpReg {
ScopedTmpRegF64ScopedTmpRegF64169     explicit ScopedTmpRegF64(Encoder *encoder) : ScopedTmpReg(encoder, FLOAT64_TYPE) {}
170 };
171 
172 struct ScopedTmpRegRef : public ScopedTmpReg {
ScopedTmpRegRefScopedTmpRegRef173     explicit ScopedTmpRegRef(Encoder *encoder) : ScopedTmpReg(encoder, encoder->GetRefType()) {}
174 };
175 
176 class ScopedLiveTmpReg : public ScopedTmpReg {
177 public:
ScopedLiveTmpReg(Encoder * encoder)178     explicit ScopedLiveTmpReg(Encoder *encoder) : ScopedTmpReg(encoder)
179     {
180         encoder->AddRegInLiveMask(GetReg());
181     }
ScopedLiveTmpReg(Encoder * encoder,TypeInfo type)182     ScopedLiveTmpReg(Encoder *encoder, TypeInfo type) : ScopedTmpReg(encoder, type)
183     {
184         encoder->AddRegInLiveMask(GetReg());
185     }
186 
Release()187     void Release() final
188     {
189         ASSERT(GetReg().IsValid());
190         GetEncoder()->RemoveRegFromLiveMask(GetReg());
191         ScopedTmpReg::Release();
192     }
193 
~ScopedLiveTmpReg()194     ~ScopedLiveTmpReg() override
195     {
196         if (GetReg().IsValid()) {
197             GetEncoder()->RemoveRegFromLiveMask(GetReg());
198         }
199     }
200 
201     NO_COPY_SEMANTIC(ScopedLiveTmpReg);
202     NO_MOVE_SEMANTIC(ScopedLiveTmpReg);
203 };
204 
205 }  // namespace ark::compiler
206 
207 #endif  // COMPILER_OPTIMIZER_CODEGEN_SCOPED_TMP_REG_H
208