• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "modify_maximal_reconvergence.h"
16 
17 #include "source/opt/ir_context.h"
18 #include "source/util/make_unique.h"
19 
20 namespace spvtools {
21 namespace opt {
22 
Process()23 Pass::Status ModifyMaximalReconvergence::Process() {
24   bool changed = false;
25   if (add_) {
26     changed = AddMaximalReconvergence();
27   } else {
28     changed = RemoveMaximalReconvergence();
29   }
30   return changed ? Pass::Status::SuccessWithChange
31                  : Pass::Status::SuccessWithoutChange;
32 }
33 
AddMaximalReconvergence()34 bool ModifyMaximalReconvergence::AddMaximalReconvergence() {
35   bool changed = false;
36   bool has_extension = false;
37   bool has_shader =
38       context()->get_feature_mgr()->HasCapability(spv::Capability::Shader);
39   for (auto extension : context()->extensions()) {
40     if (extension.GetOperand(0).AsString() == "SPV_KHR_maximal_reconvergence") {
41       has_extension = true;
42       break;
43     }
44   }
45 
46   std::unordered_set<uint32_t> entry_points_with_mode;
47   for (auto mode : get_module()->execution_modes()) {
48     if (spv::ExecutionMode(mode.GetSingleWordInOperand(1)) ==
49         spv::ExecutionMode::MaximallyReconvergesKHR) {
50       entry_points_with_mode.insert(mode.GetSingleWordInOperand(0));
51     }
52   }
53 
54   for (auto entry_point : get_module()->entry_points()) {
55     const uint32_t id = entry_point.GetSingleWordInOperand(1);
56     if (!entry_points_with_mode.count(id)) {
57       changed = true;
58       if (!has_extension) {
59         context()->AddExtension("SPV_KHR_maximal_reconvergence");
60         has_extension = true;
61       }
62       if (!has_shader) {
63         context()->AddCapability(spv::Capability::Shader);
64         has_shader = true;
65       }
66       context()->AddExecutionMode(MakeUnique<Instruction>(
67           context(), spv::Op::OpExecutionMode, 0, 0,
68           std::initializer_list<Operand>{
69               {SPV_OPERAND_TYPE_ID, {id}},
70               {SPV_OPERAND_TYPE_EXECUTION_MODE,
71                {static_cast<uint32_t>(
72                    spv::ExecutionMode::MaximallyReconvergesKHR)}}}));
73       entry_points_with_mode.insert(id);
74     }
75   }
76 
77   return changed;
78 }
79 
RemoveMaximalReconvergence()80 bool ModifyMaximalReconvergence::RemoveMaximalReconvergence() {
81   bool changed = false;
82   std::vector<Instruction*> to_remove;
83   Instruction* mode = &*get_module()->execution_mode_begin();
84   while (mode) {
85     if (mode->opcode() != spv::Op::OpExecutionMode &&
86         mode->opcode() != spv::Op::OpExecutionModeId) {
87       break;
88     }
89     if (spv::ExecutionMode(mode->GetSingleWordInOperand(1)) ==
90         spv::ExecutionMode::MaximallyReconvergesKHR) {
91       mode = context()->KillInst(mode);
92       changed = true;
93     } else {
94       mode = mode->NextNode();
95     }
96   }
97 
98   changed |=
99       context()->RemoveExtension(Extension::kSPV_KHR_maximal_reconvergence);
100   return changed;
101 }
102 }  // namespace opt
103 }  // namespace spvtools
104