1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <cassert>
16 #include <sstream>
17 #include <string>
18
19 #include "source/diagnostic.h"
20 #include "source/lint/divergence_analysis.h"
21 #include "source/lint/lints.h"
22 #include "source/opt/basic_block.h"
23 #include "source/opt/cfg.h"
24 #include "source/opt/control_dependence.h"
25 #include "source/opt/def_use_manager.h"
26 #include "source/opt/dominator_analysis.h"
27 #include "source/opt/instruction.h"
28 #include "source/opt/ir_context.h"
29 #include "spirv-tools/libspirv.h"
30 #include "spirv/unified1/spirv.h"
31
32 namespace spvtools {
33 namespace lint {
34 namespace lints {
35 namespace {
36 // Returns the %name[id], where `name` is the first name associated with the
37 // given id, or just %id if one is not found.
GetFriendlyName(opt::IRContext * context,uint32_t id)38 std::string GetFriendlyName(opt::IRContext* context, uint32_t id) {
39 auto names = context->GetNames(id);
40 std::stringstream ss;
41 ss << "%";
42 if (names.empty()) {
43 ss << id;
44 } else {
45 opt::Instruction* inst_name = names.begin()->second;
46 if (inst_name->opcode() == SpvOpName) {
47 ss << names.begin()->second->GetInOperand(0).AsString();
48 ss << "[" << id << "]";
49 } else {
50 ss << id;
51 }
52 }
53 return ss.str();
54 }
55
InstructionHasDerivative(const opt::Instruction & inst)56 bool InstructionHasDerivative(const opt::Instruction& inst) {
57 static const SpvOp derivative_opcodes[] = {
58 // Implicit derivatives.
59 SpvOpImageSampleImplicitLod,
60 SpvOpImageSampleDrefImplicitLod,
61 SpvOpImageSampleProjImplicitLod,
62 SpvOpImageSampleProjDrefImplicitLod,
63 SpvOpImageSparseSampleImplicitLod,
64 SpvOpImageSparseSampleDrefImplicitLod,
65 SpvOpImageSparseSampleProjImplicitLod,
66 SpvOpImageSparseSampleProjDrefImplicitLod,
67 // Explicit derivatives.
68 SpvOpDPdx,
69 SpvOpDPdy,
70 SpvOpFwidth,
71 SpvOpDPdxFine,
72 SpvOpDPdyFine,
73 SpvOpFwidthFine,
74 SpvOpDPdxCoarse,
75 SpvOpDPdyCoarse,
76 SpvOpFwidthCoarse,
77 };
78 return std::find(std::begin(derivative_opcodes), std::end(derivative_opcodes),
79 inst.opcode()) != std::end(derivative_opcodes);
80 }
81
Warn(opt::IRContext * context,opt::Instruction * inst)82 spvtools::DiagnosticStream Warn(opt::IRContext* context,
83 opt::Instruction* inst) {
84 if (inst == nullptr) {
85 return DiagnosticStream({0, 0, 0}, context->consumer(), "", SPV_WARNING);
86 } else {
87 // TODO(kuhar): Use line numbers based on debug info.
88 return DiagnosticStream(
89 {0, 0, 0}, context->consumer(),
90 inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES),
91 SPV_WARNING);
92 }
93 }
94
PrintDivergenceFlow(opt::IRContext * context,DivergenceAnalysis div,uint32_t id)95 void PrintDivergenceFlow(opt::IRContext* context, DivergenceAnalysis div,
96 uint32_t id) {
97 opt::analysis::DefUseManager* def_use = context->get_def_use_mgr();
98 opt::CFG* cfg = context->cfg();
99 while (id != 0) {
100 bool is_block = def_use->GetDef(id)->opcode() == SpvOpLabel;
101 if (is_block) {
102 Warn(context, nullptr)
103 << "block " << GetFriendlyName(context, id) << " is divergent";
104 uint32_t source = div.GetDivergenceSource(id);
105 // Skip intermediate blocks.
106 while (source != 0 && def_use->GetDef(source)->opcode() == SpvOpLabel) {
107 id = source;
108 source = div.GetDivergenceSource(id);
109 }
110 if (source == 0) break;
111 spvtools::opt::Instruction* branch =
112 cfg->block(div.GetDivergenceDependenceSource(id))->terminator();
113 Warn(context, branch)
114 << "because it depends on a conditional branch on divergent value "
115 << GetFriendlyName(context, source) << "";
116 id = source;
117 } else {
118 Warn(context, nullptr)
119 << "value " << GetFriendlyName(context, id) << " is divergent";
120 uint32_t source = div.GetDivergenceSource(id);
121 opt::Instruction* def = def_use->GetDef(id);
122 opt::Instruction* source_def =
123 source == 0 ? nullptr : def_use->GetDef(source);
124 // First print data -> data dependencies.
125 while (source != 0 && source_def->opcode() != SpvOpLabel) {
126 Warn(context, def_use->GetDef(id))
127 << "because " << GetFriendlyName(context, id) << " uses value "
128 << GetFriendlyName(context, source)
129 << "in its definition, which is divergent";
130 id = source;
131 def = source_def;
132 source = div.GetDivergenceSource(id);
133 source_def = def_use->GetDef(source);
134 }
135 if (source == 0) {
136 Warn(context, def) << "because it has a divergent definition";
137 break;
138 }
139 Warn(context, def) << "because it is conditionally set in block "
140 << GetFriendlyName(context, source);
141 id = source;
142 }
143 }
144 }
145 } // namespace
146
CheckDivergentDerivatives(opt::IRContext * context)147 bool CheckDivergentDerivatives(opt::IRContext* context) {
148 DivergenceAnalysis div(*context);
149 for (opt::Function& func : *context->module()) {
150 div.Run(&func);
151 for (const opt::BasicBlock& bb : func) {
152 for (const opt::Instruction& inst : bb) {
153 if (InstructionHasDerivative(inst) &&
154 div.GetDivergenceLevel(bb.id()) >
155 DivergenceAnalysis::DivergenceLevel::kPartiallyUniform) {
156 Warn(context, nullptr)
157 << "derivative with divergent control flow"
158 << " located in block " << GetFriendlyName(context, bb.id());
159 PrintDivergenceFlow(context, div, bb.id());
160 }
161 }
162 }
163 }
164 return true;
165 }
166
167 } // namespace lints
168 } // namespace lint
169 } // namespace spvtools
170