1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/loop_fusion.h"
16
17 #include <algorithm>
18 #include <vector>
19
20 #include "source/opt/ir_context.h"
21 #include "source/opt/loop_dependence.h"
22 #include "source/opt/loop_descriptor.h"
23
24 namespace spvtools {
25 namespace opt {
26
27 namespace {
28
29 // Append all the loops nested in |loop| to |loops|.
CollectChildren(Loop * loop,std::vector<const Loop * > * loops)30 void CollectChildren(Loop* loop, std::vector<const Loop*>* loops) {
31 for (auto child : *loop) {
32 loops->push_back(child);
33 if (child->NumImmediateChildren() != 0) {
34 CollectChildren(child, loops);
35 }
36 }
37 }
38
39 // Return the set of locations accessed by |stores| and |loads|.
GetLocationsAccessed(const std::map<Instruction *,std::vector<Instruction * >> & stores,const std::map<Instruction *,std::vector<Instruction * >> & loads)40 std::set<Instruction*> GetLocationsAccessed(
41 const std::map<Instruction*, std::vector<Instruction*>>& stores,
42 const std::map<Instruction*, std::vector<Instruction*>>& loads) {
43 std::set<Instruction*> locations{};
44
45 for (const auto& kv : stores) {
46 locations.insert(std::get<0>(kv));
47 }
48
49 for (const auto& kv : loads) {
50 locations.insert(std::get<0>(kv));
51 }
52
53 return locations;
54 }
55
56 // Append all dependences from |sources| to |destinations| to |dependences|.
GetDependences(std::vector<DistanceVector> * dependences,LoopDependenceAnalysis * analysis,const std::vector<Instruction * > & sources,const std::vector<Instruction * > & destinations,size_t num_entries)57 void GetDependences(std::vector<DistanceVector>* dependences,
58 LoopDependenceAnalysis* analysis,
59 const std::vector<Instruction*>& sources,
60 const std::vector<Instruction*>& destinations,
61 size_t num_entries) {
62 for (auto source : sources) {
63 for (auto destination : destinations) {
64 DistanceVector dist(num_entries);
65 if (!analysis->GetDependence(source, destination, &dist)) {
66 dependences->push_back(dist);
67 }
68 }
69 }
70 }
71
72 // Apped all instructions in |block| to |instructions|.
AddInstructionsInBlock(std::vector<Instruction * > * instructions,BasicBlock * block)73 void AddInstructionsInBlock(std::vector<Instruction*>* instructions,
74 BasicBlock* block) {
75 for (auto& inst : *block) {
76 instructions->push_back(&inst);
77 }
78
79 instructions->push_back(block->GetLabelInst());
80 }
81
82 } // namespace
83
UsedInContinueOrConditionBlock(Instruction * phi_instruction,Loop * loop)84 bool LoopFusion::UsedInContinueOrConditionBlock(Instruction* phi_instruction,
85 Loop* loop) {
86 auto condition_block = loop->FindConditionBlock()->id();
87 auto continue_block = loop->GetContinueBlock()->id();
88 auto not_used = context_->get_def_use_mgr()->WhileEachUser(
89 phi_instruction,
90 [this, condition_block, continue_block](Instruction* instruction) {
91 auto block_id = context_->get_instr_block(instruction)->id();
92 return block_id != condition_block && block_id != continue_block;
93 });
94
95 return !not_used;
96 }
97
RemoveIfNotUsedContinueOrConditionBlock(std::vector<Instruction * > * instructions,Loop * loop)98 void LoopFusion::RemoveIfNotUsedContinueOrConditionBlock(
99 std::vector<Instruction*>* instructions, Loop* loop) {
100 instructions->erase(
101 std::remove_if(std::begin(*instructions), std::end(*instructions),
102 [this, loop](Instruction* instruction) {
103 return !UsedInContinueOrConditionBlock(instruction,
104 loop);
105 }),
106 std::end(*instructions));
107 }
108
AreCompatible()109 bool LoopFusion::AreCompatible() {
110 // Check that the loops are in the same function.
111 if (loop_0_->GetHeaderBlock()->GetParent() !=
112 loop_1_->GetHeaderBlock()->GetParent()) {
113 return false;
114 }
115
116 // Check that both loops have pre-header blocks.
117 if (!loop_0_->GetPreHeaderBlock() || !loop_1_->GetPreHeaderBlock()) {
118 return false;
119 }
120
121 // Check there are no breaks.
122 if (context_->cfg()->preds(loop_0_->GetMergeBlock()->id()).size() != 1 ||
123 context_->cfg()->preds(loop_1_->GetMergeBlock()->id()).size() != 1) {
124 return false;
125 }
126
127 // Check there are no continues.
128 if (context_->cfg()->preds(loop_0_->GetContinueBlock()->id()).size() != 1 ||
129 context_->cfg()->preds(loop_1_->GetContinueBlock()->id()).size() != 1) {
130 return false;
131 }
132
133 // |GetInductionVariables| returns all OpPhi in the header. Check that both
134 // loops have exactly one that is used in the continue and condition blocks.
135 std::vector<Instruction*> inductions_0{}, inductions_1{};
136 loop_0_->GetInductionVariables(inductions_0);
137 RemoveIfNotUsedContinueOrConditionBlock(&inductions_0, loop_0_);
138
139 if (inductions_0.size() != 1) {
140 return false;
141 }
142
143 induction_0_ = inductions_0.front();
144
145 loop_1_->GetInductionVariables(inductions_1);
146 RemoveIfNotUsedContinueOrConditionBlock(&inductions_1, loop_1_);
147
148 if (inductions_1.size() != 1) {
149 return false;
150 }
151
152 induction_1_ = inductions_1.front();
153
154 if (!CheckInit()) {
155 return false;
156 }
157
158 if (!CheckCondition()) {
159 return false;
160 }
161
162 if (!CheckStep()) {
163 return false;
164 }
165
166 // Check adjacency, |loop_0_| should come just before |loop_1_|.
167 // There is always at least one block between loops, even if it's empty.
168 // We'll check at most 2 preceding blocks.
169
170 auto pre_header_1 = loop_1_->GetPreHeaderBlock();
171
172 std::vector<BasicBlock*> block_to_check{};
173 block_to_check.push_back(pre_header_1);
174
175 if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
176 // Follow CFG for one more block.
177 auto preds = context_->cfg()->preds(pre_header_1->id());
178 if (preds.size() == 1) {
179 auto block = &*containing_function_->FindBlock(preds.front());
180 if (block == loop_0_->GetMergeBlock()) {
181 block_to_check.push_back(block);
182 } else {
183 return false;
184 }
185 } else {
186 return false;
187 }
188 }
189
190 // Check that the separating blocks are either empty or only contains a store
191 // to a local variable that is never read (left behind by
192 // '--eliminate-local-multi-store'). Also allow OpPhi, since the loop could be
193 // in LCSSA form.
194 for (auto block : block_to_check) {
195 for (auto& inst : *block) {
196 if (inst.opcode() == SpvOpStore) {
197 // Get the definition of the target to check it's function scope so
198 // there are no observable side effects.
199 auto variable =
200 context_->get_def_use_mgr()->GetDef(inst.GetSingleWordInOperand(0));
201
202 if (variable->opcode() != SpvOpVariable ||
203 variable->GetSingleWordInOperand(0) != SpvStorageClassFunction) {
204 return false;
205 }
206
207 // Check the target is never loaded.
208 auto is_used = false;
209 context_->get_def_use_mgr()->ForEachUse(
210 inst.GetSingleWordInOperand(0),
211 [&is_used](Instruction* use_inst, uint32_t) {
212 if (use_inst->opcode() == SpvOpLoad) {
213 is_used = true;
214 }
215 });
216
217 if (is_used) {
218 return false;
219 }
220 } else if (inst.opcode() == SpvOpPhi) {
221 if (inst.NumInOperands() != 2) {
222 return false;
223 }
224 } else if (inst.opcode() != SpvOpBranch) {
225 return false;
226 }
227 }
228 }
229
230 return true;
231 } // namespace opt
232
ContainsBarriersOrFunctionCalls(Loop * loop)233 bool LoopFusion::ContainsBarriersOrFunctionCalls(Loop* loop) {
234 for (const auto& block : loop->GetBlocks()) {
235 for (const auto& inst : *containing_function_->FindBlock(block)) {
236 auto opcode = inst.opcode();
237 if (opcode == SpvOpFunctionCall || opcode == SpvOpControlBarrier ||
238 opcode == SpvOpMemoryBarrier || opcode == SpvOpTypeNamedBarrier ||
239 opcode == SpvOpNamedBarrierInitialize ||
240 opcode == SpvOpMemoryNamedBarrier) {
241 return true;
242 }
243 }
244 }
245
246 return false;
247 }
248
CheckInit()249 bool LoopFusion::CheckInit() {
250 int64_t loop_0_init;
251 if (!loop_0_->GetInductionInitValue(induction_0_, &loop_0_init)) {
252 return false;
253 }
254
255 int64_t loop_1_init;
256 if (!loop_1_->GetInductionInitValue(induction_1_, &loop_1_init)) {
257 return false;
258 }
259
260 if (loop_0_init != loop_1_init) {
261 return false;
262 }
263
264 return true;
265 }
266
CheckCondition()267 bool LoopFusion::CheckCondition() {
268 auto condition_0 = loop_0_->GetConditionInst();
269 auto condition_1 = loop_1_->GetConditionInst();
270
271 if (!loop_0_->IsSupportedCondition(condition_0->opcode()) ||
272 !loop_1_->IsSupportedCondition(condition_1->opcode())) {
273 return false;
274 }
275
276 if (condition_0->opcode() != condition_1->opcode()) {
277 return false;
278 }
279
280 for (uint32_t i = 0; i < condition_0->NumInOperandWords(); ++i) {
281 auto arg_0 = context_->get_def_use_mgr()->GetDef(
282 condition_0->GetSingleWordInOperand(i));
283 auto arg_1 = context_->get_def_use_mgr()->GetDef(
284 condition_1->GetSingleWordInOperand(i));
285
286 if (arg_0 == induction_0_ && arg_1 == induction_1_) {
287 continue;
288 }
289
290 if (arg_0 == induction_0_ && arg_1 != induction_1_) {
291 return false;
292 }
293
294 if (arg_1 == induction_1_ && arg_0 != induction_0_) {
295 return false;
296 }
297
298 if (arg_0 != arg_1) {
299 return false;
300 }
301 }
302
303 return true;
304 }
305
CheckStep()306 bool LoopFusion::CheckStep() {
307 auto scalar_analysis = context_->GetScalarEvolutionAnalysis();
308 SENode* induction_node_0 = scalar_analysis->SimplifyExpression(
309 scalar_analysis->AnalyzeInstruction(induction_0_));
310 if (!induction_node_0->AsSERecurrentNode()) {
311 return false;
312 }
313
314 SENode* induction_step_0 =
315 induction_node_0->AsSERecurrentNode()->GetCoefficient();
316 if (!induction_step_0->AsSEConstantNode()) {
317 return false;
318 }
319
320 SENode* induction_node_1 = scalar_analysis->SimplifyExpression(
321 scalar_analysis->AnalyzeInstruction(induction_1_));
322 if (!induction_node_1->AsSERecurrentNode()) {
323 return false;
324 }
325
326 SENode* induction_step_1 =
327 induction_node_1->AsSERecurrentNode()->GetCoefficient();
328 if (!induction_step_1->AsSEConstantNode()) {
329 return false;
330 }
331
332 if (*induction_step_0 != *induction_step_1) {
333 return false;
334 }
335
336 return true;
337 }
338
LocationToMemOps(const std::vector<Instruction * > & mem_ops)339 std::map<Instruction*, std::vector<Instruction*>> LoopFusion::LocationToMemOps(
340 const std::vector<Instruction*>& mem_ops) {
341 std::map<Instruction*, std::vector<Instruction*>> location_map{};
342
343 for (auto instruction : mem_ops) {
344 auto access_location = context_->get_def_use_mgr()->GetDef(
345 instruction->GetSingleWordInOperand(0));
346
347 while (access_location->opcode() == SpvOpAccessChain) {
348 access_location = context_->get_def_use_mgr()->GetDef(
349 access_location->GetSingleWordInOperand(0));
350 }
351
352 location_map[access_location].push_back(instruction);
353 }
354
355 return location_map;
356 }
357
358 std::pair<std::vector<Instruction*>, std::vector<Instruction*>>
GetLoadsAndStoresInLoop(Loop * loop)359 LoopFusion::GetLoadsAndStoresInLoop(Loop* loop) {
360 std::vector<Instruction*> loads{};
361 std::vector<Instruction*> stores{};
362
363 for (auto block_id : loop->GetBlocks()) {
364 if (block_id == loop->GetContinueBlock()->id()) {
365 continue;
366 }
367
368 for (auto& instruction : *containing_function_->FindBlock(block_id)) {
369 if (instruction.opcode() == SpvOpLoad) {
370 loads.push_back(&instruction);
371 } else if (instruction.opcode() == SpvOpStore) {
372 stores.push_back(&instruction);
373 }
374 }
375 }
376
377 return std::make_pair(loads, stores);
378 }
379
IsUsedInLoop(Instruction * instruction,Loop * loop)380 bool LoopFusion::IsUsedInLoop(Instruction* instruction, Loop* loop) {
381 auto not_used = context_->get_def_use_mgr()->WhileEachUser(
382 instruction, [this, loop](Instruction* user) {
383 auto block_id = context_->get_instr_block(user)->id();
384 return !loop->IsInsideLoop(block_id);
385 });
386
387 return !not_used;
388 }
389
IsLegal()390 bool LoopFusion::IsLegal() {
391 assert(AreCompatible() && "Fusion can't be legal, loops are not compatible.");
392
393 // Bail out if there are function calls as they could have side-effects that
394 // cause dependencies or if there are any barriers.
395 if (ContainsBarriersOrFunctionCalls(loop_0_) ||
396 ContainsBarriersOrFunctionCalls(loop_1_)) {
397 return false;
398 }
399
400 std::vector<Instruction*> phi_instructions{};
401 loop_0_->GetInductionVariables(phi_instructions);
402
403 // Check no OpPhi in |loop_0_| is used in |loop_1_|.
404 for (auto phi_instruction : phi_instructions) {
405 if (IsUsedInLoop(phi_instruction, loop_1_)) {
406 return false;
407 }
408 }
409
410 // Check no LCSSA OpPhi in merge block of |loop_0_| is used in |loop_1_|.
411 auto phi_used = false;
412 loop_0_->GetMergeBlock()->ForEachPhiInst(
413 [this, &phi_used](Instruction* phi_instruction) {
414 phi_used |= IsUsedInLoop(phi_instruction, loop_1_);
415 });
416
417 if (phi_used) {
418 return false;
419 }
420
421 // Grab loads & stores from both loops.
422 auto loads_stores_0 = GetLoadsAndStoresInLoop(loop_0_);
423 auto loads_stores_1 = GetLoadsAndStoresInLoop(loop_1_);
424
425 // Build memory location to operation maps.
426 auto load_locs_0 = LocationToMemOps(std::get<0>(loads_stores_0));
427 auto store_locs_0 = LocationToMemOps(std::get<1>(loads_stores_0));
428
429 auto load_locs_1 = LocationToMemOps(std::get<0>(loads_stores_1));
430 auto store_locs_1 = LocationToMemOps(std::get<1>(loads_stores_1));
431
432 // Get the locations accessed in both loops.
433 auto locations_0 = GetLocationsAccessed(store_locs_0, load_locs_0);
434 auto locations_1 = GetLocationsAccessed(store_locs_1, load_locs_1);
435
436 std::vector<Instruction*> potential_clashes{};
437
438 std::set_intersection(std::begin(locations_0), std::end(locations_0),
439 std::begin(locations_1), std::end(locations_1),
440 std::back_inserter(potential_clashes));
441
442 // If the loops don't access the same variables, the fusion is legal.
443 if (potential_clashes.empty()) {
444 return true;
445 }
446
447 // Find variables that have at least one store.
448 std::vector<Instruction*> potential_clashes_with_stores{};
449 for (auto location : potential_clashes) {
450 if (store_locs_0.find(location) != std::end(store_locs_0) ||
451 store_locs_1.find(location) != std::end(store_locs_1)) {
452 potential_clashes_with_stores.push_back(location);
453 }
454 }
455
456 // If there are only loads to the same variables, the fusion is legal.
457 if (potential_clashes_with_stores.empty()) {
458 return true;
459 }
460
461 // Else if loads and at least one store (across loops) to the same variable
462 // there is a potential dependence and we need to check the dependence
463 // distance.
464
465 // Find all the loops in this loop nest for the dependency analysis.
466 std::vector<const Loop*> loops{};
467
468 // Find the parents.
469 for (auto current_loop = loop_0_; current_loop != nullptr;
470 current_loop = current_loop->GetParent()) {
471 loops.push_back(current_loop);
472 }
473
474 auto this_loop_position = loops.size() - 1;
475 std::reverse(std::begin(loops), std::end(loops));
476
477 // Find the children.
478 CollectChildren(loop_0_, &loops);
479 CollectChildren(loop_1_, &loops);
480
481 // Check that any dependes created are legal. That means the fused loops do
482 // not have any dependencies with dependence distance greater than 0 that did
483 // not exist in the original loops.
484
485 LoopDependenceAnalysis analysis(context_, loops);
486
487 analysis.GetScalarEvolution()->AddLoopsToPretendAreTheSame(
488 {loop_0_, loop_1_});
489
490 for (auto location : potential_clashes_with_stores) {
491 // Analyse dependences from |loop_0_| to |loop_1_|.
492 std::vector<DistanceVector> dependences;
493 // Read-After-Write.
494 GetDependences(&dependences, &analysis, store_locs_0[location],
495 load_locs_1[location], loops.size());
496 // Write-After-Read.
497 GetDependences(&dependences, &analysis, load_locs_0[location],
498 store_locs_1[location], loops.size());
499 // Write-After-Write.
500 GetDependences(&dependences, &analysis, store_locs_0[location],
501 store_locs_1[location], loops.size());
502
503 // Check that the induction variables either don't appear in the subscripts
504 // or the dependence distance is negative.
505 for (const auto& dependence : dependences) {
506 const auto& entry = dependence.GetEntries()[this_loop_position];
507 if ((entry.dependence_information ==
508 DistanceEntry::DependenceInformation::DISTANCE &&
509 entry.distance < 1) ||
510 (entry.dependence_information ==
511 DistanceEntry::DependenceInformation::IRRELEVANT)) {
512 continue;
513 } else {
514 return false;
515 }
516 }
517 }
518
519 return true;
520 }
521
ReplacePhiParentWith(Instruction * inst,uint32_t orig_block,uint32_t new_block)522 void ReplacePhiParentWith(Instruction* inst, uint32_t orig_block,
523 uint32_t new_block) {
524 if (inst->GetSingleWordInOperand(1) == orig_block) {
525 inst->SetInOperand(1, {new_block});
526 } else {
527 inst->SetInOperand(3, {new_block});
528 }
529 }
530
Fuse()531 void LoopFusion::Fuse() {
532 assert(AreCompatible() && "Can't fuse, loops aren't compatible");
533 assert(IsLegal() && "Can't fuse, illegal");
534
535 // Save the pointers/ids, won't be found in the middle of doing modifications.
536 auto header_1 = loop_1_->GetHeaderBlock()->id();
537 auto condition_1 = loop_1_->FindConditionBlock()->id();
538 auto continue_1 = loop_1_->GetContinueBlock()->id();
539 auto continue_0 = loop_0_->GetContinueBlock()->id();
540 auto condition_block_of_0 = loop_0_->FindConditionBlock();
541
542 // Find the blocks whose branches need updating.
543 auto first_block_of_1 = &*(++containing_function_->FindBlock(condition_1));
544 auto last_block_of_1 = &*(--containing_function_->FindBlock(continue_1));
545 auto last_block_of_0 = &*(--containing_function_->FindBlock(continue_0));
546
547 // Update the branch for |last_block_of_loop_0| to go to |first_block_of_1|.
548 last_block_of_0->ForEachSuccessorLabel(
549 [first_block_of_1](uint32_t* succ) { *succ = first_block_of_1->id(); });
550
551 // Update the branch for the |last_block_of_loop_1| to go to the continue
552 // block of |loop_0_|.
553 last_block_of_1->ForEachSuccessorLabel(
554 [this](uint32_t* succ) { *succ = loop_0_->GetContinueBlock()->id(); });
555
556 // Update merge block id in the header of |loop_0_| to the merge block of
557 // |loop_1_|.
558 loop_0_->GetHeaderBlock()->ForEachInst([this](Instruction* inst) {
559 if (inst->opcode() == SpvOpLoopMerge) {
560 inst->SetInOperand(0, {loop_1_->GetMergeBlock()->id()});
561 }
562 });
563
564 // Update condition branch target in |loop_0_| to the merge block of
565 // |loop_1_|.
566 condition_block_of_0->ForEachInst([this](Instruction* inst) {
567 if (inst->opcode() == SpvOpBranchConditional) {
568 auto loop_0_merge_block_id = loop_0_->GetMergeBlock()->id();
569
570 if (inst->GetSingleWordInOperand(1) == loop_0_merge_block_id) {
571 inst->SetInOperand(1, {loop_1_->GetMergeBlock()->id()});
572 } else {
573 inst->SetInOperand(2, {loop_1_->GetMergeBlock()->id()});
574 }
575 }
576 });
577
578 // Move OpPhi instructions not corresponding to the induction variable from
579 // the header of |loop_1_| to the header of |loop_0_|.
580 std::vector<Instruction*> instructions_to_move{};
581 for (auto& instruction : *loop_1_->GetHeaderBlock()) {
582 if (instruction.opcode() == SpvOpPhi && &instruction != induction_1_) {
583 instructions_to_move.push_back(&instruction);
584 }
585 }
586
587 for (auto& it : instructions_to_move) {
588 it->RemoveFromList();
589 it->InsertBefore(induction_0_);
590 }
591
592 // Update the OpPhi parents to the correct blocks in |loop_0_|.
593 loop_0_->GetHeaderBlock()->ForEachPhiInst([this](Instruction* i) {
594 ReplacePhiParentWith(i, loop_1_->GetPreHeaderBlock()->id(),
595 loop_0_->GetPreHeaderBlock()->id());
596
597 ReplacePhiParentWith(i, loop_1_->GetContinueBlock()->id(),
598 loop_0_->GetContinueBlock()->id());
599 });
600
601 // Update instruction to block mapping & DefUseManager.
602 for (auto& phi_instruction : instructions_to_move) {
603 context_->set_instr_block(phi_instruction, loop_0_->GetHeaderBlock());
604 context_->get_def_use_mgr()->AnalyzeInstUse(phi_instruction);
605 }
606
607 // Replace the uses of the induction variable of |loop_1_| with that the
608 // induction variable of |loop_0_|.
609 context_->ReplaceAllUsesWith(induction_1_->result_id(),
610 induction_0_->result_id());
611
612 // Replace LCSSA OpPhi in merge block of |loop_0_|.
613 loop_0_->GetMergeBlock()->ForEachPhiInst([this](Instruction* instruction) {
614 context_->ReplaceAllUsesWith(instruction->result_id(),
615 instruction->GetSingleWordInOperand(0));
616 });
617
618 // Update LCSSA OpPhi in merge block of |loop_1_|.
619 loop_1_->GetMergeBlock()->ForEachPhiInst(
620 [condition_block_of_0](Instruction* instruction) {
621 instruction->SetInOperand(1, {condition_block_of_0->id()});
622 });
623
624 // Move the continue block of |loop_0_| after the last block of |loop_1_|.
625 containing_function_->MoveBasicBlockToAfter(continue_0, last_block_of_1);
626
627 // Gather all instructions to be killed from |loop_1_| (induction variable
628 // initialisation, header, condition and continue blocks).
629 std::vector<Instruction*> instr_to_delete{};
630 AddInstructionsInBlock(&instr_to_delete, loop_1_->GetPreHeaderBlock());
631 AddInstructionsInBlock(&instr_to_delete, loop_1_->GetHeaderBlock());
632 AddInstructionsInBlock(&instr_to_delete, loop_1_->FindConditionBlock());
633 AddInstructionsInBlock(&instr_to_delete, loop_1_->GetContinueBlock());
634
635 // There was an additional empty block between the loops, kill that too.
636 if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
637 AddInstructionsInBlock(&instr_to_delete, loop_0_->GetMergeBlock());
638 }
639
640 // Update the CFG, so it wouldn't need invalidating.
641 auto cfg = context_->cfg();
642
643 cfg->ForgetBlock(loop_1_->GetPreHeaderBlock());
644 cfg->ForgetBlock(loop_1_->GetHeaderBlock());
645 cfg->ForgetBlock(loop_1_->FindConditionBlock());
646 cfg->ForgetBlock(loop_1_->GetContinueBlock());
647
648 if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
649 cfg->ForgetBlock(loop_0_->GetMergeBlock());
650 }
651
652 cfg->RemoveEdge(last_block_of_0->id(), loop_0_->GetContinueBlock()->id());
653 cfg->AddEdge(last_block_of_0->id(), first_block_of_1->id());
654
655 cfg->AddEdge(last_block_of_1->id(), loop_0_->GetContinueBlock()->id());
656
657 cfg->AddEdge(loop_0_->GetContinueBlock()->id(),
658 loop_1_->GetHeaderBlock()->id());
659
660 cfg->AddEdge(condition_block_of_0->id(), loop_1_->GetMergeBlock()->id());
661
662 // Update DefUseManager.
663 auto def_use_mgr = context_->get_def_use_mgr();
664
665 // Uses of labels that are in updated branches need analysing.
666 def_use_mgr->AnalyzeInstUse(last_block_of_0->terminator());
667 def_use_mgr->AnalyzeInstUse(last_block_of_1->terminator());
668 def_use_mgr->AnalyzeInstUse(loop_0_->GetHeaderBlock()->GetLoopMergeInst());
669 def_use_mgr->AnalyzeInstUse(condition_block_of_0->terminator());
670
671 // Update the LoopDescriptor, so it wouldn't need invalidating.
672 auto ld = context_->GetLoopDescriptor(containing_function_);
673
674 // Create a copy, so the iterator wouldn't be invalidated.
675 std::vector<Loop*> loops_to_add_remove{};
676 for (auto child_loop : *loop_1_) {
677 loops_to_add_remove.push_back(child_loop);
678 }
679
680 for (auto child_loop : loops_to_add_remove) {
681 loop_1_->RemoveChildLoop(child_loop);
682 loop_0_->AddNestedLoop(child_loop);
683 }
684
685 auto loop_1_blocks = loop_1_->GetBlocks();
686
687 for (auto block : loop_1_blocks) {
688 loop_1_->RemoveBasicBlock(block);
689 if (block != header_1 && block != condition_1 && block != continue_1) {
690 loop_0_->AddBasicBlock(block);
691 if ((*ld)[block] == loop_1_) {
692 ld->SetBasicBlockToLoop(block, loop_0_);
693 }
694 }
695
696 if ((*ld)[block] == loop_1_) {
697 ld->ForgetBasicBlock(block);
698 }
699 }
700
701 loop_1_->RemoveBasicBlock(loop_1_->GetPreHeaderBlock()->id());
702 ld->ForgetBasicBlock(loop_1_->GetPreHeaderBlock()->id());
703
704 if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
705 loop_0_->RemoveBasicBlock(loop_0_->GetMergeBlock()->id());
706 ld->ForgetBasicBlock(loop_0_->GetMergeBlock()->id());
707 }
708
709 loop_0_->SetMergeBlock(loop_1_->GetMergeBlock());
710
711 loop_1_->ClearBlocks();
712
713 ld->RemoveLoop(loop_1_);
714
715 // Kill unnecessary instructions and remove all empty blocks.
716 for (auto inst : instr_to_delete) {
717 context_->KillInst(inst);
718 }
719
720 containing_function_->RemoveEmptyBlocks();
721
722 // Invalidate analyses.
723 context_->InvalidateAnalysesExceptFor(
724 IRContext::Analysis::kAnalysisInstrToBlockMapping |
725 IRContext::Analysis::kAnalysisLoopAnalysis |
726 IRContext::Analysis::kAnalysisDefUse | IRContext::Analysis::kAnalysisCFG);
727 }
728
729 } // namespace opt
730 } // namespace spvtools
731