• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <cassert>
16 #include <sstream>
17 #include <string>
18 
19 #include "source/diagnostic.h"
20 #include "source/lint/divergence_analysis.h"
21 #include "source/lint/lints.h"
22 #include "source/opt/basic_block.h"
23 #include "source/opt/cfg.h"
24 #include "source/opt/control_dependence.h"
25 #include "source/opt/def_use_manager.h"
26 #include "source/opt/dominator_analysis.h"
27 #include "source/opt/instruction.h"
28 #include "source/opt/ir_context.h"
29 #include "spirv-tools/libspirv.h"
30 
31 namespace spvtools {
32 namespace lint {
33 namespace lints {
34 namespace {
35 // Returns the %name[id], where `name` is the first name associated with the
36 // given id, or just %id if one is not found.
GetFriendlyName(opt::IRContext * context,uint32_t id)37 std::string GetFriendlyName(opt::IRContext* context, uint32_t id) {
38   auto names = context->GetNames(id);
39   std::stringstream ss;
40   ss << "%";
41   if (names.empty()) {
42     ss << id;
43   } else {
44     opt::Instruction* inst_name = names.begin()->second;
45     if (inst_name->opcode() == spv::Op::OpName) {
46       ss << names.begin()->second->GetInOperand(0).AsString();
47       ss << "[" << id << "]";
48     } else {
49       ss << id;
50     }
51   }
52   return ss.str();
53 }
54 
InstructionHasDerivative(const opt::Instruction & inst)55 bool InstructionHasDerivative(const opt::Instruction& inst) {
56   static const spv::Op derivative_opcodes[] = {
57       // Implicit derivatives.
58       spv::Op::OpImageSampleImplicitLod,
59       spv::Op::OpImageSampleDrefImplicitLod,
60       spv::Op::OpImageSampleProjImplicitLod,
61       spv::Op::OpImageSampleProjDrefImplicitLod,
62       spv::Op::OpImageSparseSampleImplicitLod,
63       spv::Op::OpImageSparseSampleDrefImplicitLod,
64       spv::Op::OpImageSparseSampleProjImplicitLod,
65       spv::Op::OpImageSparseSampleProjDrefImplicitLod,
66       // Explicit derivatives.
67       spv::Op::OpDPdx,
68       spv::Op::OpDPdy,
69       spv::Op::OpFwidth,
70       spv::Op::OpDPdxFine,
71       spv::Op::OpDPdyFine,
72       spv::Op::OpFwidthFine,
73       spv::Op::OpDPdxCoarse,
74       spv::Op::OpDPdyCoarse,
75       spv::Op::OpFwidthCoarse,
76   };
77   return std::find(std::begin(derivative_opcodes), std::end(derivative_opcodes),
78                    inst.opcode()) != std::end(derivative_opcodes);
79 }
80 
Warn(opt::IRContext * context,opt::Instruction * inst)81 spvtools::DiagnosticStream Warn(opt::IRContext* context,
82                                 opt::Instruction* inst) {
83   if (inst == nullptr) {
84     return DiagnosticStream({0, 0, 0}, context->consumer(), "", SPV_WARNING);
85   } else {
86     // TODO(kuhar): Use line numbers based on debug info.
87     return DiagnosticStream(
88         {0, 0, 0}, context->consumer(),
89         inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES),
90         SPV_WARNING);
91   }
92 }
93 
PrintDivergenceFlow(opt::IRContext * context,DivergenceAnalysis div,uint32_t id)94 void PrintDivergenceFlow(opt::IRContext* context, DivergenceAnalysis div,
95                          uint32_t id) {
96   opt::analysis::DefUseManager* def_use = context->get_def_use_mgr();
97   opt::CFG* cfg = context->cfg();
98   while (id != 0) {
99     bool is_block = def_use->GetDef(id)->opcode() == spv::Op::OpLabel;
100     if (is_block) {
101       Warn(context, nullptr)
102           << "block " << GetFriendlyName(context, id) << " is divergent";
103       uint32_t source = div.GetDivergenceSource(id);
104       // Skip intermediate blocks.
105       while (source != 0 &&
106              def_use->GetDef(source)->opcode() == spv::Op::OpLabel) {
107         id = source;
108         source = div.GetDivergenceSource(id);
109       }
110       if (source == 0) break;
111       spvtools::opt::Instruction* branch =
112           cfg->block(div.GetDivergenceDependenceSource(id))->terminator();
113       Warn(context, branch)
114           << "because it depends on a conditional branch on divergent value "
115           << GetFriendlyName(context, source) << "";
116       id = source;
117     } else {
118       Warn(context, nullptr)
119           << "value " << GetFriendlyName(context, id) << " is divergent";
120       uint32_t source = div.GetDivergenceSource(id);
121       opt::Instruction* def = def_use->GetDef(id);
122       opt::Instruction* source_def =
123           source == 0 ? nullptr : def_use->GetDef(source);
124       // First print data -> data dependencies.
125       while (source != 0 && source_def->opcode() != spv::Op::OpLabel) {
126         Warn(context, def_use->GetDef(id))
127             << "because " << GetFriendlyName(context, id) << " uses value "
128             << GetFriendlyName(context, source)
129             << "in its definition, which is divergent";
130         id = source;
131         def = source_def;
132         source = div.GetDivergenceSource(id);
133         source_def = def_use->GetDef(source);
134       }
135       if (source == 0) {
136         Warn(context, def) << "because it has a divergent definition";
137         break;
138       }
139       Warn(context, def) << "because it is conditionally set in block "
140                          << GetFriendlyName(context, source);
141       id = source;
142     }
143   }
144 }
145 }  // namespace
146 
CheckDivergentDerivatives(opt::IRContext * context)147 bool CheckDivergentDerivatives(opt::IRContext* context) {
148   DivergenceAnalysis div(*context);
149   for (opt::Function& func : *context->module()) {
150     div.Run(&func);
151     for (const opt::BasicBlock& bb : func) {
152       for (const opt::Instruction& inst : bb) {
153         if (InstructionHasDerivative(inst) &&
154             div.GetDivergenceLevel(bb.id()) >
155                 DivergenceAnalysis::DivergenceLevel::kPartiallyUniform) {
156           Warn(context, nullptr)
157               << "derivative with divergent control flow"
158               << " located in block " << GetFriendlyName(context, bb.id());
159           PrintDivergenceFlow(context, div, bb.id());
160         }
161       }
162     }
163   }
164   return true;
165 }
166 
167 }  // namespace lints
168 }  // namespace lint
169 }  // namespace spvtools
170