• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "optimize_string_concat.h"
17 
18 #include "compiler_logger.h"
19 
20 #include "optimizer/analysis/alias_analysis.h"
21 #include "optimizer/analysis/bounds_analysis.h"
22 #include "optimizer/analysis/dominators_tree.h"
23 #include "optimizer/ir/analysis.h"
24 #include "optimizer/ir/datatype.h"
25 #include "optimizer/ir/graph.h"
26 #include "optimizer/ir/inst.h"
27 
28 #include "optimizer/ir/runtime_interface.h"
29 #include "optimizer/optimizations/cleanup.h"
30 #include "optimizer/optimizations/string_builder_utils.h"
31 
32 namespace ark::compiler {
33 
OptimizeStringConcat(Graph * graph)34 OptimizeStringConcat::OptimizeStringConcat(Graph *graph) : Optimization(graph) {}
35 
GetStringBuilderClassId(Graph * graph)36 RuntimeInterface::IdType GetStringBuilderClassId(Graph *graph)
37 {
38     auto runtime = graph->GetRuntime();
39     auto klass = runtime->GetStringBuilderClass();
40     return klass == nullptr ? 0 : runtime->GetClassIdWithinFile(graph->GetMethod(), klass);
41 }
42 
RunImpl()43 bool OptimizeStringConcat::RunImpl()
44 {
45     bool isApplied = false;
46 
47     if (GetStringBuilderClassId(GetGraph()) == 0) {
48         COMPILER_LOG(WARNING, OPTIMIZE_STRING_CONCAT) << "StringBuilder class not found";
49         return isApplied;
50     }
51 
52     if (GetGraph()->IsAotMode()) {
53         // NOTE(mivanov): Creating StringBuilder.ctor calls in AOT mode not yet supported
54         return isApplied;
55     }
56 
57     for (auto block : GetGraph()->GetBlocksRPO()) {
58         for (auto inst : block->Insts()) {
59             if (!IsMethodStringConcat(inst)) {
60                 continue;
61             }
62 
63             ReplaceStringConcatWithStringBuilderAppend(inst);
64             isApplied = true;
65         }
66     }
67 
68     COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Optimize String.concat complete";
69 
70     if (isApplied) {
71         GetGraph()->RunPass<compiler::Cleanup>();
72     }
73 
74     return isApplied;
75 }
76 
InvalidateAnalyses()77 void OptimizeStringConcat::InvalidateAnalyses()
78 {
79     GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
80     GetGraph()->InvalidateAnalysis<AliasAnalysis>();
81 }
82 
GetPhiConstantInput(Inst * phi)83 Inst *GetPhiConstantInput(Inst *phi)
84 {
85     ASSERT(phi->GetInputsCount() == 2U);  // NOLINT(readability-magic-numbers)
86     ASSERT(phi->GetDataFlowInput(1)->IsPhi());
87     auto inputInst0 = phi->GetDataFlowInput(0);
88     if (inputInst0->IsConst()) {
89         return inputInst0;
90     }
91     if (inputInst0->IsPhi()) {
92         return GetPhiConstantInput(inputInst0);
93     }
94     UNREACHABLE();
95 }
96 
GetArrayLength(Inst * newArray)97 Inst *GetArrayLength(Inst *newArray)
98 {
99     ASSERT(newArray->GetInputsCount() > 1);
100     auto inputInst1 = newArray->GetDataFlowInput(1);
101     if (inputInst1->IsConst()) {
102         return inputInst1;
103     }
104     if (inputInst1->IsPhi()) {
105         return GetPhiConstantInput(inputInst1);
106     }
107     UNREACHABLE();
108 }
109 
CreateInstructionStringBuilderInstance(Graph * graph,uint32_t pc,SaveStateInst * saveState)110 Inst *CreateInstructionStringBuilderInstance(Graph *graph, uint32_t pc, SaveStateInst *saveState)
111 {
112     auto runtime = graph->GetRuntime();
113     auto method = graph->GetMethod();
114 
115     auto classId = GetStringBuilderClassId(graph);
116     auto loadClass =
117         graph->CreateInstLoadAndInitClass(DataType::REFERENCE, pc, CopySaveState(graph, saveState),
118                                           TypeIdMixin {classId, method}, runtime->ResolveType(method, classId));
119     auto newObject = graph->CreateInstNewObject(DataType::REFERENCE, pc, loadClass, CopySaveState(graph, saveState),
120                                                 TypeIdMixin {classId, method});
121 
122     return newObject;
123 }
124 
CreateStringBuilderAppendStringIntrinsic(Graph * graph,Inst * instance,Inst * arg,SaveStateInst * saveState)125 IntrinsicInst *CreateStringBuilderAppendStringIntrinsic(Graph *graph, Inst *instance, Inst *arg,
126                                                         SaveStateInst *saveState)
127 {
128     auto appendIntrinsic = graph->CreateInstIntrinsic(graph->GetRuntime()->GetStringBuilderAppendStringIntrinsicId());
129     ASSERT(appendIntrinsic->RequireState());
130 
131     appendIntrinsic->SetType(DataType::REFERENCE);
132     auto saveStateClone = CopySaveState(graph, saveState);
133     appendIntrinsic->SetInputs(
134         graph->GetAllocator(),
135         {{instance, instance->GetType()}, {arg, arg->GetType()}, {saveStateClone, saveStateClone->GetType()}});
136 
137     return appendIntrinsic;
138 }
139 
CreateStringBuilderToStringIntrinsic(Graph * graph,Inst * instance,SaveStateInst * saveState)140 IntrinsicInst *CreateStringBuilderToStringIntrinsic(Graph *graph, Inst *instance, SaveStateInst *saveState)
141 {
142     auto toStringCall = graph->CreateInstIntrinsic(graph->GetRuntime()->GetStringBuilderToStringIntrinsicId());
143     ASSERT(toStringCall->RequireState());
144 
145     toStringCall->SetType(DataType::REFERENCE);
146     auto saveStateClone = CopySaveState(graph, saveState);
147     toStringCall->SetInputs(graph->GetAllocator(),
148                             {{instance, instance->GetType()}, {saveStateClone, saveStateClone->GetType()}});
149 
150     return toStringCall;
151 }
152 
CreateStringBuilderDefaultConstructorCall(Graph * graph,Inst * instance,SaveStateInst * saveState)153 CallInst *CreateStringBuilderDefaultConstructorCall(Graph *graph, Inst *instance, SaveStateInst *saveState)
154 {
155     auto runtime = graph->GetRuntime();
156     auto method = runtime->GetStringBuilderDefaultConstructor();
157     auto methodId = runtime->GetMethodId(method);
158 
159     auto ctorCall = graph->CreateInstCallStatic(DataType::VOID, instance->GetPc(), methodId, method);
160     ASSERT(ctorCall->RequireState());
161 
162     auto saveStateClone = CopySaveState(graph, saveState);
163     ctorCall->SetInputs(graph->GetAllocator(),
164                         {{instance, instance->GetType()}, {saveStateClone, saveStateClone->GetType()}});
165 
166     return ctorCall;
167 }
168 
CreateLoadArray(Graph * graph,Inst * array,Inst * index)169 Inst *CreateLoadArray(Graph *graph, Inst *array, Inst *index)
170 {
171     return graph->CreateInstLoadArray(DataType::REFERENCE, array->GetPc(), array, index);
172 }
173 
CreateLoadArray(Graph * graph,Inst * array,uint64_t index)174 Inst *CreateLoadArray(Graph *graph, Inst *array, uint64_t index)
175 {
176     return CreateLoadArray(graph, array, graph->FindOrCreateConstant(index));
177 }
178 
CreateLenArray(Graph * graph,Inst * newArray)179 Inst *CreateLenArray(Graph *graph, Inst *newArray)
180 {
181     return graph->CreateInstLenArray(DataType::INT32, newArray->GetPc(), newArray);
182 }
183 
FixBrokenSaveStates(Inst * source,Inst * target)184 void OptimizeStringConcat::FixBrokenSaveStates(Inst *source, Inst *target)
185 {
186     if (source->IsMovableObject()) {
187         ssb_.SearchAndCreateMissingObjInSaveState(GetGraph(), source, target);
188     }
189 }
190 
CreateAppendArgsIntrinsics(Inst * instance,Inst * args,uint64_t arrayLengthValue,SaveStateInst * saveState)191 void OptimizeStringConcat::CreateAppendArgsIntrinsics(Inst *instance, Inst *args, uint64_t arrayLengthValue,
192                                                       SaveStateInst *saveState)
193 {
194     for (uint64_t index = 0; index < arrayLengthValue; ++index) {
195         auto arg = CreateLoadArray(GetGraph(), args, index);
196         auto appendIntrinsic = CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, arg, saveState);
197         InsertBeforeWithInputs(appendIntrinsic, saveState);
198 
199         FixBrokenSaveStates(arg, appendIntrinsic);
200         FixBrokenSaveStates(instance, appendIntrinsic);
201 
202         COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
203             << "Insert StringBuilder.append intrinsic (id=" << appendIntrinsic->GetId() << ")";
204     }
205 }
206 
CreateSafePoint(Graph * graph,uint32_t pc,SaveStateInst * saveState)207 Inst *CreateSafePoint(Graph *graph, uint32_t pc, SaveStateInst *saveState)
208 {
209     auto safePoint =
210         graph->CreateInstSafePoint(pc, graph->GetMethod(), saveState->GetCallerInst(), saveState->GetInliningDepth());
211 
212     for (size_t index = 0; index < saveState->GetInputsCount(); ++index) {
213         safePoint->AppendInput(saveState->GetInput(index));
214         safePoint->SetVirtualRegister(index, saveState->GetVirtualRegister(index));
215     }
216 
217     return safePoint;
218 }
219 
// Builds a counted loop that appends every element of `args` to `instance`. Used when the number of
// concat arguments is not known at compile time.
//
// The block containing `concatCall` is split right after it. The resulting CFG is:
//
//   preHeader (ends with concatCall) -> header
//   header   -> postExit   (loop exit)
//   header   -> backEdge
//   backEdge -> header
//
// header:   i = phi(0, i + 1); SafePoint; exit when arrayLength <= i
// backEdge: arg = args[i]; StringBuilder append(instance, arg); i + 1
//
// Parameters:
//   instance    - the StringBuilder object being appended to
//   str         - the concat receiver; only used here to bridge it into save states
//   args        - the array of remaining concat arguments
//   arrayLength - LenArray of `args`, the loop bound (must already be inserted by the caller)
//   concatCall  - the call being replaced; its block becomes the loop pre-header
// Returns the block holding the instructions that followed `concatCall` (the loop exit target).
BasicBlock *OptimizeStringConcat::CreateAppendArgsLoop(Inst *instance, Inst *str, Inst *args,
                                                       LengthMethodInst *arrayLength, Inst *concatCall)
{
    auto preHeader = concatCall->GetBasicBlock();
    auto postExit = preHeader->SplitBlockAfterInstruction(concatCall, false);
    auto saveState = concatCall->GetSaveState();

    // Create loop CFG
    // NOTE(review): successor order appears significant - postExit is added to `header` first, so it is
    // presumably the IfImm "true" target (exit when the compare holds). Confirm against BasicBlock semantics.
    auto header = GetGraph()->CreateEmptyBlock(preHeader);
    auto backEdge = GetGraph()->CreateEmptyBlock(preHeader);
    preHeader->AddSucc(header);
    header->AddSucc(postExit);
    header->AddSucc(backEdge);
    backEdge->AddSucc(header);

    // Declare loop variables
    auto start = GetGraph()->FindOrCreateConstant(0);
    auto stop = arrayLength;
    auto step = GetGraph()->FindOrCreateConstant(1);

    auto pc = instance->GetPc();

    // Build header
    // The SafePoint copies the concat call's save state. Presumably it lets the runtime interrupt a long
    // loop - TODO confirm.
    // The compare computes `stop <= induction`, i.e. the loop leaves once every element has been appended.
    auto induction = GetGraph()->CreateInstPhi(DataType::INT32, pc);
    auto safePoint = CreateSafePoint(GetGraph(), pc, saveState);
    auto compare =
        GetGraph()->CreateInstCompare(DataType::BOOL, pc, stop, induction, DataType::INT32, ConditionCode::CC_LE);
    auto ifImm = GetGraph()->CreateInstIfImm(DataType::BOOL, pc, compare, 0, DataType::BOOL, ConditionCode::CC_NE);
    header->AppendPhi(induction);
    header->AppendInsts({
        safePoint,
        compare,
        ifImm,
    });

    // Build back edge
    // The intrinsic's own save-state copy must be placed before it, so it is appended explicitly here.
    auto arg = CreateLoadArray(GetGraph(), args, induction);
    auto appendIntrinsic = CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, arg, saveState);
    auto add = GetGraph()->CreateInstAdd(DataType::INT32, pc, induction, step);
    backEdge->AppendInsts({
        arg,
        appendIntrinsic->GetSaveState(),
        appendIntrinsic,
        add,
    });

    // Connect loop induction variable inputs
    // Input order presumably matches header's predecessor order: preHeader was linked to header first
    // (initial value 0), backEdge second (incremented value).
    induction->AppendInput(start);
    induction->AppendInput(add);

    // Bridge the movable objects used around the loop body into the save states (including the header
    // SafePoint path) between their definitions and the append intrinsic.
    FixBrokenSaveStates(str, appendIntrinsic);
    FixBrokenSaveStates(args, appendIntrinsic);
    FixBrokenSaveStates(arg, appendIntrinsic);
    FixBrokenSaveStates(instance, appendIntrinsic);

    COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
        << "Insert StringBuilder.append intrinsic (id=" << appendIntrinsic->GetId() << ")";

    return postExit;
}
280 
ReplaceStringConcatWithStringBuilderAppend(Inst * concatCall)281 void OptimizeStringConcat::ReplaceStringConcatWithStringBuilderAppend(Inst *concatCall)
282 {
283     // Input:
284     //  let result = str.concat(...args)
285     //
286     // Output:
287     //  let instance = new StringBuilder(str)
288     //  instance.append(args[0])
289     //  ...
290     //  instance.append(args[args.length-1])
291     //  let result = instance.toString()
292 
293     ASSERT(concatCall->GetInputsCount() > 1);
294 
295     auto str = concatCall->GetDataFlowInput(0);
296     auto args = concatCall->GetDataFlowInput(1);
297 
298     auto instance = CreateInstructionStringBuilderInstance(GetGraph(), concatCall->GetPc(), concatCall->GetSaveState());
299     InsertBeforeWithInputs(instance, concatCall->GetSaveState());
300 
301     auto ctorCall = CreateStringBuilderDefaultConstructorCall(GetGraph(), instance, concatCall->GetSaveState());
302     InsertBeforeWithSaveState(ctorCall, concatCall->GetSaveState());
303     auto appendArgIntrinsic =
304         CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, str, concatCall->GetSaveState());
305     InsertBeforeWithSaveState(appendArgIntrinsic, concatCall->GetSaveState());
306     FixBrokenSaveStates(instance, appendArgIntrinsic);
307     FixBrokenSaveStates(str, appendArgIntrinsic);
308 
309     COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
310         << "Insert StringBuilder.append intrinsic (id=" << appendArgIntrinsic->GetId() << ")";
311 
312     auto toStringCall = CreateStringBuilderToStringIntrinsic(GetGraph(), instance, concatCall->GetSaveState());
313 
314     if (args->GetOpcode() == Opcode::NewArray) {
315         auto arrayLength = GetArrayLength(args);
316         CreateAppendArgsIntrinsics(instance, args, arrayLength->CastToConstant()->GetIntValue(),
317                                    concatCall->GetSaveState());
318         InsertBeforeWithSaveState(toStringCall, concatCall->GetSaveState());
319 
320         COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Replace String.concat call (id=" << concatCall->GetId()
321                                                     << ") with StringBuilder instance (id=" << instance->GetId() << ")";
322     } else {
323         auto arrayLength = CreateLenArray(GetGraph(), args);
324         concatCall->GetSaveState()->InsertBefore(arrayLength);
325         auto postExit = CreateAppendArgsLoop(instance, str, args, arrayLength->CastToLenArray(), concatCall);
326 
327         postExit->PrependInst(toStringCall);
328         postExit->PrependInst(toStringCall->GetSaveState());
329 
330         InvalidateBlocksOrderAnalyzes(GetGraph());
331         GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
332 
333         COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Replace String.concat call (id=" << concatCall->GetId()
334                                                     << ") with StringBuilder instance (id=" << instance->GetId() << ")";
335     }
336 
337     FixBrokenSaveStates(instance, toStringCall);
338 
339     concatCall->ReplaceUsers(toStringCall);
340 
341     concatCall->ClearFlag(inst_flags::NO_DCE);
342     if (concatCall->GetInput(0).GetInst()->IsCheck()) {
343         concatCall->GetInput(0).GetInst()->ClearFlag(inst_flags::NO_DCE);
344     }
345 }
346 
347 }  // namespace ark::compiler
348