1 // Copyright (c) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "modify_maximal_reconvergence.h"
16
17 #include "source/opt/ir_context.h"
18 #include "source/util/make_unique.h"
19
20 namespace spvtools {
21 namespace opt {
22
Process()23 Pass::Status ModifyMaximalReconvergence::Process() {
24 bool changed = false;
25 if (add_) {
26 changed = AddMaximalReconvergence();
27 } else {
28 changed = RemoveMaximalReconvergence();
29 }
30 return changed ? Pass::Status::SuccessWithChange
31 : Pass::Status::SuccessWithoutChange;
32 }
33
AddMaximalReconvergence()34 bool ModifyMaximalReconvergence::AddMaximalReconvergence() {
35 bool changed = false;
36 bool has_extension = false;
37 bool has_shader =
38 context()->get_feature_mgr()->HasCapability(spv::Capability::Shader);
39 for (auto extension : context()->extensions()) {
40 if (extension.GetOperand(0).AsString() == "SPV_KHR_maximal_reconvergence") {
41 has_extension = true;
42 break;
43 }
44 }
45
46 std::unordered_set<uint32_t> entry_points_with_mode;
47 for (auto mode : get_module()->execution_modes()) {
48 if (spv::ExecutionMode(mode.GetSingleWordInOperand(1)) ==
49 spv::ExecutionMode::MaximallyReconvergesKHR) {
50 entry_points_with_mode.insert(mode.GetSingleWordInOperand(0));
51 }
52 }
53
54 for (auto entry_point : get_module()->entry_points()) {
55 const uint32_t id = entry_point.GetSingleWordInOperand(1);
56 if (!entry_points_with_mode.count(id)) {
57 changed = true;
58 if (!has_extension) {
59 context()->AddExtension("SPV_KHR_maximal_reconvergence");
60 has_extension = true;
61 }
62 if (!has_shader) {
63 context()->AddCapability(spv::Capability::Shader);
64 has_shader = true;
65 }
66 context()->AddExecutionMode(MakeUnique<Instruction>(
67 context(), spv::Op::OpExecutionMode, 0, 0,
68 std::initializer_list<Operand>{
69 {SPV_OPERAND_TYPE_ID, {id}},
70 {SPV_OPERAND_TYPE_EXECUTION_MODE,
71 {static_cast<uint32_t>(
72 spv::ExecutionMode::MaximallyReconvergesKHR)}}}));
73 entry_points_with_mode.insert(id);
74 }
75 }
76
77 return changed;
78 }
79
RemoveMaximalReconvergence()80 bool ModifyMaximalReconvergence::RemoveMaximalReconvergence() {
81 bool changed = false;
82 std::vector<Instruction*> to_remove;
83 Instruction* mode = &*get_module()->execution_mode_begin();
84 while (mode) {
85 if (mode->opcode() != spv::Op::OpExecutionMode &&
86 mode->opcode() != spv::Op::OpExecutionModeId) {
87 break;
88 }
89 if (spv::ExecutionMode(mode->GetSingleWordInOperand(1)) ==
90 spv::ExecutionMode::MaximallyReconvergesKHR) {
91 mode = context()->KillInst(mode);
92 changed = true;
93 } else {
94 mode = mode->NextNode();
95 }
96 }
97
98 changed |=
99 context()->RemoveExtension(Extension::kSPV_KHR_maximal_reconvergence);
100 return changed;
101 }
102 } // namespace opt
103 } // namespace spvtools
104