1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "object_type_propagation.h"
17 #include "optimizer/ir/basicblock.h"
18 #include "optimizer/ir/inst.h"
19
20 namespace ark::compiler {
RunImpl()21 bool ObjectTypePropagation::RunImpl()
22 {
23 VisitGraph();
24 InstVector visitedPhis(GetGraph()->GetLocalAllocator()->Adapter());
25 visitedPhis_ = &visitedPhis;
26 visited_ = GetGraph()->NewMarker();
27 for (auto bb : GetGraph()->GetBlocksRPO()) {
28 for (auto phi : bb->PhiInsts()) {
29 auto typeInfo = GetPhiTypeInfo(phi);
30 for (auto visitedPhi : visitedPhis) {
31 ASSERT(visitedPhi->GetObjectTypeInfo() == ObjectTypeInfo::UNKNOWN);
32 visitedPhi->SetObjectTypeInfo(typeInfo);
33 }
34 visitedPhis.clear();
35 }
36 }
37 GetGraph()->EraseMarker(visited_);
38 return true;
39 }
40
VisitNewObject(GraphVisitor * v,Inst * i)41 void ObjectTypePropagation::VisitNewObject(GraphVisitor *v, Inst *i)
42 {
43 auto self = static_cast<ObjectTypePropagation *>(v);
44 auto inst = i->CastToNewObject();
45 auto klass = self->GetGraph()->GetRuntime()->GetClass(inst->GetMethod(), inst->GetTypeId());
46 if (klass != nullptr) {
47 inst->SetObjectTypeInfo({klass, true});
48 }
49 }
50
VisitNewArray(GraphVisitor * v,Inst * i)51 void ObjectTypePropagation::VisitNewArray(GraphVisitor *v, Inst *i)
52 {
53 auto self = static_cast<ObjectTypePropagation *>(v);
54 auto inst = i->CastToNewArray();
55 auto klass = self->GetGraph()->GetRuntime()->GetClass(inst->GetMethod(), inst->GetTypeId());
56 if (klass != nullptr) {
57 inst->SetObjectTypeInfo({klass, true});
58 }
59 }
60
VisitLoadString(GraphVisitor * v,Inst * i)61 void ObjectTypePropagation::VisitLoadString(GraphVisitor *v, Inst *i)
62 {
63 auto self = static_cast<ObjectTypePropagation *>(v);
64 auto inst = i->CastToLoadString();
65 auto klass = self->GetGraph()->GetRuntime()->GetStringClass(inst->GetMethod(), nullptr);
66 if (klass != nullptr) {
67 inst->SetObjectTypeInfo({klass, true});
68 }
69 }
70
VisitLoadArray(GraphVisitor * v,Inst * i)71 void ObjectTypePropagation::VisitLoadArray([[maybe_unused]] GraphVisitor *v, [[maybe_unused]] Inst *i)
72 {
73 // LoadArray should be processed more carefully, because it may contain object of the derived class with own method
74 // implementation. We need to check all array stores and method calls between NewArray and LoadArray.
75 // NOTE(mshertennikov): Support it.
76 }
77
VisitLoadObject(GraphVisitor * v,Inst * i)78 void ObjectTypePropagation::VisitLoadObject(GraphVisitor *v, Inst *i)
79 {
80 if (i->GetType() != DataType::REFERENCE || i->CastToLoadObject()->GetObjectType() != ObjectType::MEM_OBJECT) {
81 return;
82 }
83 auto self = static_cast<ObjectTypePropagation *>(v);
84 auto inst = i->CastToLoadObject();
85 auto fieldId = inst->GetTypeId();
86 if (fieldId == 0) {
87 return;
88 }
89 auto runtime = self->GetGraph()->GetRuntime();
90 auto method = inst->GetMethod();
91 auto typeId = runtime->GetFieldValueTypeId(method, fieldId);
92 auto klass = runtime->GetClass(method, typeId);
93 if (klass != nullptr) {
94 auto isExact = runtime->GetClassType(method, typeId) == ClassType::FINAL_CLASS;
95 inst->SetObjectTypeInfo({klass, isExact});
96 }
97 }
98
VisitCallStatic(GraphVisitor * v,Inst * i)99 void ObjectTypePropagation::VisitCallStatic(GraphVisitor *v, Inst *i)
100 {
101 ProcessManagedCall(v, i->CastToCallStatic());
102 }
103
VisitCallVirtual(GraphVisitor * v,Inst * i)104 void ObjectTypePropagation::VisitCallVirtual(GraphVisitor *v, Inst *i)
105 {
106 ProcessManagedCall(v, i->CastToCallVirtual());
107 }
108
VisitNullCheck(GraphVisitor * v,Inst * i)109 void ObjectTypePropagation::VisitNullCheck([[maybe_unused]] GraphVisitor *v, Inst *i)
110 {
111 auto inst = i->CastToNullCheck();
112 inst->SetObjectTypeInfo(inst->GetInput(0).GetInst()->GetObjectTypeInfo());
113 }
114
VisitRefTypeCheck(GraphVisitor * v,Inst * i)115 void ObjectTypePropagation::VisitRefTypeCheck([[maybe_unused]] GraphVisitor *v, Inst *i)
116 {
117 auto inst = i->CastToRefTypeCheck();
118 inst->SetObjectTypeInfo(inst->GetInput(0).GetInst()->GetObjectTypeInfo());
119 }
120
VisitParameter(GraphVisitor * v,Inst * i)121 void ObjectTypePropagation::VisitParameter([[maybe_unused]] GraphVisitor *v, Inst *i)
122 {
123 auto inst = i->CastToParameter();
124 auto graph = i->GetBasicBlock()->GetGraph();
125 if (inst->GetType() != DataType::REFERENCE || graph->IsBytecodeOptimizer() || inst->HasObjectTypeInfo()) {
126 return;
127 }
128 auto refNum = inst->GetArgRefNumber();
129 auto runtime = graph->GetRuntime();
130 auto method = graph->GetMethod();
131 RuntimeInterface::ClassPtr klass;
132 if (refNum == ParameterInst::INVALID_ARG_REF_NUM) {
133 // This parametr doesn't have ArgRefNumber
134 if (inst->GetArgNumber() != 0 || runtime->IsMethodStatic(method)) {
135 return;
136 }
137 klass = runtime->GetClass(method);
138 } else {
139 auto typeId = runtime->GetMethodArgReferenceTypeId(method, refNum);
140 klass = runtime->GetClass(method, typeId);
141 }
142 if (klass != nullptr) {
143 auto isExact = runtime->GetClassType(klass) == ClassType::FINAL_CLASS;
144 inst->SetObjectTypeInfo({klass, isExact});
145 }
146 }
147
ProcessManagedCall(GraphVisitor * v,CallInst * inst)148 void ObjectTypePropagation::ProcessManagedCall(GraphVisitor *v, CallInst *inst)
149 {
150 if (inst->GetType() != DataType::REFERENCE) {
151 return;
152 }
153 auto self = static_cast<ObjectTypePropagation *>(v);
154 auto runtime = self->GetGraph()->GetRuntime();
155 auto method = inst->GetCallMethod();
156 auto typeId = runtime->GetMethodReturnTypeId(method);
157 auto klass = runtime->GetClass(method, typeId);
158 if (klass != nullptr) {
159 auto isExact = runtime->GetClassType(method, typeId) == ClassType::FINAL_CLASS;
160 inst->SetObjectTypeInfo({klass, isExact});
161 }
162 }
163
GetPhiTypeInfo(Inst * inst)164 ObjectTypeInfo ObjectTypePropagation::GetPhiTypeInfo(Inst *inst)
165 {
166 if (!inst->IsPhi() || inst->SetMarker(visited_)) {
167 return inst->GetObjectTypeInfo();
168 }
169 auto typeInfo = ObjectTypeInfo::UNKNOWN;
170 inst->SetObjectTypeInfo(typeInfo);
171 bool needUpdate = false;
172 for (auto input : inst->GetInputs()) {
173 auto inputInfo = GetPhiTypeInfo(input.GetInst());
174 if (inputInfo == ObjectTypeInfo::UNKNOWN) {
175 ASSERT(input.GetInst()->IsPhi());
176 needUpdate = true;
177 continue;
178 }
179 if (inputInfo == ObjectTypeInfo::INVALID ||
180 (typeInfo.IsValid() && typeInfo.GetClass() != inputInfo.GetClass())) {
181 inst->SetObjectTypeInfo(ObjectTypeInfo::INVALID);
182 return ObjectTypeInfo::INVALID;
183 }
184 if (typeInfo == ObjectTypeInfo::UNKNOWN) {
185 typeInfo = inputInfo;
186 continue;
187 }
188 typeInfo = {typeInfo.GetClass(), typeInfo.IsExact() && inputInfo.IsExact()};
189 }
190 if (needUpdate) {
191 visitedPhis_->push_back(inst);
192 } else {
193 inst->SetObjectTypeInfo(typeInfo);
194 }
195 return typeInfo;
196 }
197
198 } // namespace ark::compiler
199