• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/amd_ext_to_khr.h"
16 
17 #include <set>
18 #include <string>
19 
20 #include "ir_builder.h"
21 #include "source/opt/ir_context.h"
22 #include "spv-amd-shader-ballot.insts.inc"
23 #include "type_manager.h"
24 
25 namespace spvtools {
26 namespace opt {
27 
28 namespace {
29 
// Instruction numbers of the SPV_AMD_shader_ballot extended instructions.
enum AmdShaderBallotExtOpcodes {
  AmdShaderBallotSwizzleInvocationsAMD = 1,
  AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
  AmdShaderBallotWriteInvocationAMD = 3,
  AmdShaderBallotMbcntAMD = 4
};

// Instruction numbers of the AMD trinary min/max/mid extended instructions.
// Each operation comes in a float (F), unsigned (U) and signed (S) flavour.
enum AmdShaderTrinaryMinMaxExtOpCodes {
  // min3
  FMin3AMD = 1,
  UMin3AMD = 2,
  SMin3AMD = 3,
  // max3
  FMax3AMD = 4,
  UMax3AMD = 5,
  SMax3AMD = 6,
  // mid3
  FMid3AMD = 7,
  UMid3AMD = 8,
  SMid3AMD = 9
};

// Instruction numbers of the AMD gcn_shader extended instructions.
enum AmdGcnShader {
  CubeFaceIndexAMD = 1,
  CubeFaceCoordAMD = 2,
  TimeAMD = 3
};
50 
GetUIntType(IRContext * ctx)51 analysis::Type* GetUIntType(IRContext* ctx) {
52   analysis::Integer int_type(32, false);
53   return ctx->get_type_mgr()->GetRegisteredType(&int_type);
54 }
55 
56 // Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where
57 // |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450
58 // extended instruction set that corresponds to the trinary instruction being
59 // replaced.
60 template <GLSLstd450 opcode>
ReplaceTrinaryMinMax(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)61 bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst,
62                           const std::vector<const analysis::Constant*>&) {
63   uint32_t glsl405_ext_inst_id =
64       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
65   if (glsl405_ext_inst_id == 0) {
66     ctx->AddExtInstImport("GLSL.std.450");
67     glsl405_ext_inst_id =
68         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
69   }
70 
71   InstructionBuilder ir_builder(
72       ctx, inst,
73       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
74 
75   uint32_t op1 = inst->GetSingleWordInOperand(2);
76   uint32_t op2 = inst->GetSingleWordInOperand(3);
77   uint32_t op3 = inst->GetSingleWordInOperand(4);
78 
79   Instruction* temp = ir_builder.AddNaryExtendedInstruction(
80       inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2});
81 
82   Instruction::OperandList new_operands;
83   new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
84   new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
85                           {static_cast<uint32_t>(opcode)}});
86   new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}});
87   new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}});
88 
89   inst->SetInOperands(std::move(new_operands));
90   ctx->UpdateDefUse(inst);
91   return true;
92 }
93 
94 // Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c),
95 // max(b,c)|. The three parameters are the opcode that correspond to the min,
96 // max, and clamp operations for the type of the instruction being replaced.
97 template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode>
ReplaceTrinaryMid(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)98 bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst,
99                        const std::vector<const analysis::Constant*>&) {
100   uint32_t glsl405_ext_inst_id =
101       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
102   if (glsl405_ext_inst_id == 0) {
103     ctx->AddExtInstImport("GLSL.std.450");
104     glsl405_ext_inst_id =
105         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
106   }
107 
108   InstructionBuilder ir_builder(
109       ctx, inst,
110       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
111 
112   uint32_t op1 = inst->GetSingleWordInOperand(2);
113   uint32_t op2 = inst->GetSingleWordInOperand(3);
114   uint32_t op3 = inst->GetSingleWordInOperand(4);
115 
116   Instruction* min = ir_builder.AddNaryExtendedInstruction(
117       inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode),
118       {op2, op3});
119   Instruction* max = ir_builder.AddNaryExtendedInstruction(
120       inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode),
121       {op2, op3});
122 
123   Instruction::OperandList new_operands;
124   new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
125   new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
126                           {static_cast<uint32_t>(clamp_opcode)}});
127   new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}});
128   new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}});
129   new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}});
130 
131   inst->SetInOperands(std::move(new_operands));
132   ctx->UpdateDefUse(inst);
133   return true;
134 }
135 
136 // Returns a folding rule that will replace the opcode with |opcode| and add
137 // the capabilities required.  The folding rule assumes it is folding an
138 // OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
139 template <SpvOp new_opcode>
ReplaceGroupNonuniformOperationOpCode(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)140 bool ReplaceGroupNonuniformOperationOpCode(
141     IRContext* ctx, Instruction* inst,
142     const std::vector<const analysis::Constant*>&) {
143   switch (new_opcode) {
144     case SpvOpGroupNonUniformIAdd:
145     case SpvOpGroupNonUniformFAdd:
146     case SpvOpGroupNonUniformUMin:
147     case SpvOpGroupNonUniformSMin:
148     case SpvOpGroupNonUniformFMin:
149     case SpvOpGroupNonUniformUMax:
150     case SpvOpGroupNonUniformSMax:
151     case SpvOpGroupNonUniformFMax:
152       break;
153     default:
154       assert(
155           false &&
156           "Should be replacing with a group non uniform arithmetic operation.");
157   }
158 
159   switch (inst->opcode()) {
160     case SpvOpGroupIAddNonUniformAMD:
161     case SpvOpGroupFAddNonUniformAMD:
162     case SpvOpGroupUMinNonUniformAMD:
163     case SpvOpGroupSMinNonUniformAMD:
164     case SpvOpGroupFMinNonUniformAMD:
165     case SpvOpGroupUMaxNonUniformAMD:
166     case SpvOpGroupSMaxNonUniformAMD:
167     case SpvOpGroupFMaxNonUniformAMD:
168       break;
169     default:
170       assert(false &&
171              "Should be replacing a group non uniform arithmetic operation.");
172   }
173 
174   ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic);
175   inst->SetOpcode(new_opcode);
176   return true;
177 }
178 
179 // Returns a folding rule that will replace the SwizzleInvocationsAMD extended
180 // instruction in the SPV_AMD_shader_ballot extension.
181 //
182 // The instruction
183 //
184 //  %offset = OpConstantComposite %v3uint %x %y %z %w
185 //  %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
186 //
187 // is replaced with
188 //
189 // potentially new constants and types
190 //
191 // clang-format off
192 //         %uint_max = OpConstant %uint 0xFFFFFFFF
193 //           %v4uint = OpTypeVector %uint 4
194 //     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
195 //             %null = OpConstantNull %type
196 // clang-format on
197 //
198 // and the following code in the function body
199 //
200 // clang-format off
201 //         %id = OpLoad %uint %SubgroupLocalInvocationId
202 //   %quad_idx = OpBitwiseAnd %uint %id %uint_3
203 //   %quad_ldr = OpBitwiseXor %uint %id %quad_idx
204 //  %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
205 // %target_inv = OpIAdd %uint %quad_ldr %my_offset
206 //  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
207 //    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
208 //     %result = OpSelect %type %is_active %shuffle %null
209 // clang-format on
210 //
211 // Also adding the capabilities and builtins that are needed.
ReplaceSwizzleInvocations(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)212 bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
213                                const std::vector<const analysis::Constant*>&) {
214   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
215   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
216 
217   ctx->AddExtension("SPV_KHR_shader_ballot");
218   ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
219   ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
220 
221   InstructionBuilder ir_builder(
222       ctx, inst,
223       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
224 
225   uint32_t data_id = inst->GetSingleWordInOperand(2);
226   uint32_t offset_id = inst->GetSingleWordInOperand(3);
227 
228   // Get the subgroup invocation id.
229   uint32_t var_id =
230       ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
231   assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
232   Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
233   Instruction* var_ptr_type =
234       ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
235   uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
236 
237   Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
238 
239   uint32_t quad_mask = ir_builder.GetUintConstantId(3);
240 
241   // This gives the offset in the group of 4 of this invocation.
242   Instruction* quad_idx = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseAnd,
243                                                  id->result_id(), quad_mask);
244 
245   // Get the invocation id of the first invocation in the group of 4.
246   Instruction* quad_ldr = ir_builder.AddBinaryOp(
247       uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
248 
249   // Get the offset of the target invocation from the offset vector.
250   Instruction* my_offset =
251       ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic, offset_id,
252                              quad_idx->result_id());
253 
254   // Determine the index of the invocation to read from.
255   Instruction* target_inv = ir_builder.AddBinaryOp(
256       uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
257 
258   // Do the group operations
259   uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
260   uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
261   const auto* ballot_value_const = const_mgr->GetConstant(
262       type_mgr->GetUIntVectorType(4),
263       {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
264   Instruction* ballot_value =
265       const_mgr->GetDefiningInstruction(ballot_value_const);
266   Instruction* is_active = ir_builder.AddNaryOp(
267       type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
268       {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
269   Instruction* shuffle =
270       ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
271                            {subgroup_scope, data_id, target_inv->result_id()});
272 
273   // Create the null constant to use in the select.
274   const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
275                                             std::vector<uint32_t>());
276   Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
277 
278   // Build the select.
279   inst->SetOpcode(SpvOpSelect);
280   Instruction::OperandList new_operands;
281   new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
282   new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
283   new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
284 
285   inst->SetInOperands(std::move(new_operands));
286   ctx->UpdateDefUse(inst);
287   return true;
288 }
289 
290 // Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
291 // extended instruction in the SPV_AMD_shader_ballot extension.
292 //
293 // The instruction
294 //
295 //    %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
296 //  %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask
297 //
298 // is replaced with
299 //
300 // potentially new constants and types
301 //
302 // clang-format off
303 // %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
304 //         %uint_max = OpConstant %uint 0xFFFFFFFF
305 //           %v4uint = OpTypeVector %uint 4
306 //     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
307 // clang-format on
308 //
309 // and the following code in the function body
310 //
311 // clang-format off
312 //         %id = OpLoad %uint %SubgroupLocalInvocationId
313 //   %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
314 //        %and = OpBitwiseAnd %uint %id %and_mask
315 //         %or = OpBitwiseOr %uint %and %uint_y
316 // %target_inv = OpBitwiseXor %uint %or %uint_z
317 //  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
318 //    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
319 //     %result = OpSelect %type %is_active %shuffle %uint_0
320 // clang-format on
321 //
322 // Also adding the capabilities and builtins that are needed.
ReplaceSwizzleInvocationsMasked(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)323 bool ReplaceSwizzleInvocationsMasked(
324     IRContext* ctx, Instruction* inst,
325     const std::vector<const analysis::Constant*>&) {
326   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
327   analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
328   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
329 
330   ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
331   ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
332 
333   InstructionBuilder ir_builder(
334       ctx, inst,
335       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
336 
337   // Get the operands to inst, and the components of the mask
338   uint32_t data_id = inst->GetSingleWordInOperand(2);
339 
340   Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
341   assert(mask_inst->opcode() == SpvOpConstantComposite &&
342          "The mask is suppose to be a vector constant.");
343   assert(mask_inst->NumInOperands() == 3 &&
344          "The mask is suppose to have 3 components.");
345 
346   uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
347   uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
348   uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
349 
350   // Get the subgroup invocation id.
351   uint32_t var_id =
352       ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
353   ctx->AddExtension("SPV_KHR_shader_ballot");
354   assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
355   Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
356   Instruction* var_ptr_type =
357       ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
358   uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
359 
360   Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
361 
362   // Do the bitwise operations.
363   uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
364   Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
365                                                  uint_x, mask_extended);
366   Instruction* and_result = ir_builder.AddBinaryOp(
367       uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
368   Instruction* or_result = ir_builder.AddBinaryOp(
369       uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
370   Instruction* target_inv = ir_builder.AddBinaryOp(
371       uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);
372 
373   // Do the group operations
374   uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
375   uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
376   const auto* ballot_value_const = const_mgr->GetConstant(
377       type_mgr->GetUIntVectorType(4),
378       {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
379   Instruction* ballot_value =
380       const_mgr->GetDefiningInstruction(ballot_value_const);
381   Instruction* is_active = ir_builder.AddNaryOp(
382       type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
383       {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
384   Instruction* shuffle =
385       ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
386                            {subgroup_scope, data_id, target_inv->result_id()});
387 
388   // Create the null constant to use in the select.
389   const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
390                                             std::vector<uint32_t>());
391   Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
392 
393   // Build the select.
394   inst->SetOpcode(SpvOpSelect);
395   Instruction::OperandList new_operands;
396   new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
397   new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
398   new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
399 
400   inst->SetInOperands(std::move(new_operands));
401   ctx->UpdateDefUse(inst);
402   return true;
403 }
404 
405 // Returns a folding rule that will replace the WriteInvocationAMD extended
406 // instruction in the SPV_AMD_shader_ballot extension.
407 //
408 // The instruction
409 //
410 // clang-format off
411 //    %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index
412 // clang-format on
413 //
414 // with
415 //
416 //     %id = OpLoad %uint %SubgroupLocalInvocationId
417 //    %cmp = OpIEqual %bool %id %invocation_index
418 // %result = OpSelect %type %cmp %write_value %input_value
419 //
420 // Also adding the capabilities and builtins that are needed.
ReplaceWriteInvocation(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)421 bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
422                             const std::vector<const analysis::Constant*>&) {
423   uint32_t var_id =
424       ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
425   ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
426   ctx->AddExtension("SPV_KHR_shader_ballot");
427   assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
428   Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
429   Instruction* var_ptr_type =
430       ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
431 
432   InstructionBuilder ir_builder(
433       ctx, inst,
434       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
435   Instruction* t =
436       ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
437   analysis::Bool bool_type;
438   uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
439   Instruction* cmp =
440       ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(),
441                              inst->GetSingleWordInOperand(4));
442 
443   // Build a select.
444   inst->SetOpcode(SpvOpSelect);
445   Instruction::OperandList new_operands;
446   new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
447   new_operands.push_back(inst->GetInOperand(3));
448   new_operands.push_back(inst->GetInOperand(2));
449 
450   inst->SetInOperands(std::move(new_operands));
451   ctx->UpdateDefUse(inst);
452   return true;
453 }
454 
455 // Returns a folding rule that will replace the MbcntAMD extended instruction in
456 // the SPV_AMD_shader_ballot extension.
457 //
458 // The instruction
459 //
460 //  %result = OpExtInst %uint %1 MbcntAMD %mask
461 //
462 // with
463 //
464 // Get SubgroupLtMask and convert the first 64-bits into a uint64_t because
465 // AMD's shader compiler expects a 64-bit integer mask.
466 //
467 //     %var = OpLoad %v4uint %SubgroupLtMaskKHR
468 // %shuffle = OpVectorShuffle %v2uint %var %var 0 1
469 //    %cast = OpBitcast %ulong %shuffle
470 //
471 // Perform the mask and count the bits.
472 //
473 //     %and = OpBitwiseAnd %ulong %cast %mask
474 //  %result = OpBitCount %uint %and
475 //
476 // Also adding the capabilities and builtins that are needed.
ReplaceMbcnt(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)477 bool ReplaceMbcnt(IRContext* context, Instruction* inst,
478                   const std::vector<const analysis::Constant*>&) {
479   analysis::TypeManager* type_mgr = context->get_type_mgr();
480   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
481 
482   uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask);
483   assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
484   context->AddCapability(SpvCapabilityGroupNonUniformBallot);
485   Instruction* var_inst = def_use_mgr->GetDef(var_id);
486   Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
487   Instruction* var_type =
488       def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
489   assert(var_type->opcode() == SpvOpTypeVector &&
490          "Variable is suppose to be a vector of 4 ints");
491 
492   // Get the type for the shuffle.
493   analysis::Vector temp_type(GetUIntType(context), 2);
494   const analysis::Type* shuffle_type =
495       context->get_type_mgr()->GetRegisteredType(&temp_type);
496   uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
497 
498   uint32_t mask_id = inst->GetSingleWordInOperand(2);
499   Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
500 
501   // Testing with amd's shader compiler shows that a 64-bit mask is expected.
502   assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
503   assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
504 
505   InstructionBuilder ir_builder(
506       context, inst,
507       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
508   Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
509   Instruction* shuffle = ir_builder.AddVectorShuffle(
510       shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
511   Instruction* bitcast = ir_builder.AddUnaryOp(
512       mask_inst->type_id(), SpvOpBitcast, shuffle->result_id());
513   Instruction* t = ir_builder.AddBinaryOp(mask_inst->type_id(), SpvOpBitwiseAnd,
514                                           bitcast->result_id(), mask_id);
515 
516   inst->SetOpcode(SpvOpBitCount);
517   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
518   context->UpdateDefUse(inst);
519   return true;
520 }
521 
522 // A folding rule that will replace the CubeFaceCoordAMD extended
523 // instruction in the SPV_AMD_gcn_shader_ballot.  Returns true if the folding is
524 // successful.
525 //
526 // The instruction
527 //
528 //  %result = OpExtInst %v2float %1 CubeFaceCoordAMD %input
529 //
530 // with
531 //
532 //             %x = OpCompositeExtract %float %input 0
533 //             %y = OpCompositeExtract %float %input 1
534 //             %z = OpCompositeExtract %float %input 2
535 //            %nx = OpFNegate %float %x
536 //            %ny = OpFNegate %float %y
537 //            %nz = OpFNegate %float %z
538 //            %ax = OpExtInst %float %n_1 FAbs %x
539 //            %ay = OpExtInst %float %n_1 FAbs %y
540 //            %az = OpExtInst %float %n_1 FAbs %z
541 //      %amax_x_y = OpExtInst %float %n_1 FMax %ay %ax
542 //          %amax = OpExtInst %float %n_1 FMax %az %amax_x_y
543 //        %cubema = OpFMul %float %float_2 %amax
544 //      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
545 //  %not_is_z_max = OpLogicalNot %bool %is_z_max
546 //        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
547 //      %is_y_max = OpLogicalAnd %bool %not_is_z_max %y_gt_x
548 //      %is_z_neg = OpFOrdLessThan %bool %z %float_0
549 // %cubesc_case_1 = OpSelect %float %is_z_neg %nx %x
550 //      %is_x_neg = OpFOrdLessThan %bool %x %float_0
551 // %cubesc_case_2 = OpSelect %float %is_x_neg %z %nz
552 //           %sel = OpSelect %float %is_y_max %x %cubesc_case_2
553 //        %cubesc = OpSelect %float %is_z_max %cubesc_case_1 %sel
554 //      %is_y_neg = OpFOrdLessThan %bool %y %float_0
555 // %cubetc_case_1 = OpSelect %float %is_y_neg %nz %z
556 //        %cubetc = OpSelect %float %is_y_max %cubetc_case_1 %ny
557 //          %cube = OpCompositeConstruct %v2float %cubesc %cubetc
558 //         %denom = OpCompositeConstruct %v2float %cubema %cubema
559 //           %div = OpFDiv %v2float %cube %denom
560 //        %result = OpFAdd %v2float %div %const
561 //
562 // Also adding the capabilities and builtins that are needed.
ReplaceCubeFaceCoord(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)563 bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
564                           const std::vector<const analysis::Constant*>&) {
565   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
566   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
567 
568   uint32_t float_type_id = type_mgr->GetFloatTypeId();
569   const analysis::Type* v2_float_type = type_mgr->GetFloatVectorType(2);
570   uint32_t v2_float_type_id = type_mgr->GetId(v2_float_type);
571   uint32_t bool_id = type_mgr->GetBoolTypeId();
572 
573   InstructionBuilder ir_builder(
574       ctx, inst,
575       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
576 
577   uint32_t input_id = inst->GetSingleWordInOperand(2);
578   uint32_t glsl405_ext_inst_id =
579       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
580   if (glsl405_ext_inst_id == 0) {
581     ctx->AddExtInstImport("GLSL.std.450");
582     glsl405_ext_inst_id =
583         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
584   }
585 
586   // Get the constants that will be used.
587   uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
588   uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
589   uint32_t f0_5_const_id = const_mgr->GetFloatConstId(0.5);
590   const analysis::Constant* vec_const =
591       const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id});
592   uint32_t vec_const_id =
593       const_mgr->GetDefiningInstruction(vec_const)->result_id();
594 
595   // Extract the input values.
596   Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
597   Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
598   Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});
599 
600   // Negate the input values.
601   Instruction* nx =
602       ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, x->result_id());
603   Instruction* ny =
604       ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, y->result_id());
605   Instruction* nz =
606       ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, z->result_id());
607 
608   // Get the abolsute values of the inputs.
609   Instruction* ax = ir_builder.AddNaryExtendedInstruction(
610       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
611   Instruction* ay = ir_builder.AddNaryExtendedInstruction(
612       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
613   Instruction* az = ir_builder.AddNaryExtendedInstruction(
614       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});
615 
616   // Find which values are negative.  Used in later computations.
617   Instruction* is_z_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
618                                                  z->result_id(), f0_const_id);
619   Instruction* is_y_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
620                                                  y->result_id(), f0_const_id);
621   Instruction* is_x_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
622                                                  x->result_id(), f0_const_id);
623 
624   // Compute cubema
625   Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
626       float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
627       {ax->result_id(), ay->result_id()});
628   Instruction* amax = ir_builder.AddNaryExtendedInstruction(
629       float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
630       {az->result_id(), amax_x_y->result_id()});
631   Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, SpvOpFMul,
632                                                f2_const_id, amax->result_id());
633 
634   // Do the comparisons needed for computing cubesc and cubetc.
635   Instruction* is_z_max =
636       ir_builder.AddBinaryOp(bool_id, SpvOpFOrdGreaterThanEqual,
637                              az->result_id(), amax_x_y->result_id());
638   Instruction* not_is_z_max =
639       ir_builder.AddUnaryOp(bool_id, SpvOpLogicalNot, is_z_max->result_id());
640   Instruction* y_gr_x = ir_builder.AddBinaryOp(
641       bool_id, SpvOpFOrdGreaterThanEqual, ay->result_id(), ax->result_id());
642   Instruction* is_y_max = ir_builder.AddBinaryOp(
643       bool_id, SpvOpLogicalAnd, not_is_z_max->result_id(), y_gr_x->result_id());
644 
645   // Select the correct value for cubesc.
646   Instruction* cubesc_case_1 = ir_builder.AddSelect(
647       float_type_id, is_z_neg->result_id(), nx->result_id(), x->result_id());
648   Instruction* cubesc_case_2 = ir_builder.AddSelect(
649       float_type_id, is_x_neg->result_id(), z->result_id(), nz->result_id());
650   Instruction* sel =
651       ir_builder.AddSelect(float_type_id, is_y_max->result_id(), x->result_id(),
652                            cubesc_case_2->result_id());
653   Instruction* cubesc =
654       ir_builder.AddSelect(float_type_id, is_z_max->result_id(),
655                            cubesc_case_1->result_id(), sel->result_id());
656 
657   // Select the correct value for cubetc.
658   Instruction* cubetc_case_1 = ir_builder.AddSelect(
659       float_type_id, is_y_neg->result_id(), nz->result_id(), z->result_id());
660   Instruction* cubetc =
661       ir_builder.AddSelect(float_type_id, is_y_max->result_id(),
662                            cubetc_case_1->result_id(), ny->result_id());
663 
664   // Do the division
665   Instruction* cube = ir_builder.AddCompositeConstruct(
666       v2_float_type_id, {cubesc->result_id(), cubetc->result_id()});
667   Instruction* denom = ir_builder.AddCompositeConstruct(
668       v2_float_type_id, {cubema->result_id(), cubema->result_id()});
669   Instruction* div = ir_builder.AddBinaryOp(
670       v2_float_type_id, SpvOpFDiv, cube->result_id(), denom->result_id());
671 
672   // Get the final result by adding 0.5 to |div|.
673   inst->SetOpcode(SpvOpFAdd);
674   Instruction::OperandList new_operands;
675   new_operands.push_back({SPV_OPERAND_TYPE_ID, {div->result_id()}});
676   new_operands.push_back({SPV_OPERAND_TYPE_ID, {vec_const_id}});
677 
678   inst->SetInOperands(std::move(new_operands));
679   ctx->UpdateDefUse(inst);
680   return true;
681 }
682 
683 // A folding rule that will replace the CubeFaceIndexAMD extended
684 // instruction in the SPV_AMD_gcn_shader_ballot.  Returns true if the folding
685 // is successful.
686 //
687 // The instruction
688 //
689 //  %result = OpExtInst %float %1 CubeFaceIndexAMD %input
690 //
691 // with
692 //
693 //             %x = OpCompositeExtract %float %input 0
694 //             %y = OpCompositeExtract %float %input 1
695 //             %z = OpCompositeExtract %float %input 2
696 //            %ax = OpExtInst %float %n_1 FAbs %x
697 //            %ay = OpExtInst %float %n_1 FAbs %y
698 //            %az = OpExtInst %float %n_1 FAbs %z
699 //      %is_z_neg = OpFOrdLessThan %bool %z %float_0
700 //      %is_y_neg = OpFOrdLessThan %bool %y %float_0
701 //      %is_x_neg = OpFOrdLessThan %bool %x %float_0
702 //      %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay
703 //      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
704 //        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
705 //        %case_z = OpSelect %float %is_z_neg %float_5 %float4
706 //        %case_y = OpSelect %float %is_y_neg %float_3 %float2
707 //        %case_x = OpSelect %float %is_x_neg %float_1 %float0
708 //           %sel = OpSelect %float %y_gt_x %case_y %case_x
709 //        %result = OpSelect %float %is_z_max %case_z %sel
710 //
711 // Also adding the capabilities and builtins that are needed.
ReplaceCubeFaceIndex(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)712 bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
713                           const std::vector<const analysis::Constant*>&) {
714   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
715   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
716 
717   uint32_t float_type_id = type_mgr->GetFloatTypeId();
718   uint32_t bool_id = type_mgr->GetBoolTypeId();
719 
720   InstructionBuilder ir_builder(
721       ctx, inst,
722       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
723 
724   uint32_t input_id = inst->GetSingleWordInOperand(2);
725   uint32_t glsl405_ext_inst_id =
726       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
727   if (glsl405_ext_inst_id == 0) {
728     ctx->AddExtInstImport("GLSL.std.450");
729     glsl405_ext_inst_id =
730         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
731   }
732 
733   // Get the constants that will be used.
734   uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
735   uint32_t f1_const_id = const_mgr->GetFloatConstId(1.0);
736   uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
737   uint32_t f3_const_id = const_mgr->GetFloatConstId(3.0);
738   uint32_t f4_const_id = const_mgr->GetFloatConstId(4.0);
739   uint32_t f5_const_id = const_mgr->GetFloatConstId(5.0);
740 
741   // Extract the input values.
742   Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
743   Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
744   Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});
745 
746   // Get the absolute values of the inputs.
747   Instruction* ax = ir_builder.AddNaryExtendedInstruction(
748       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
749   Instruction* ay = ir_builder.AddNaryExtendedInstruction(
750       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
751   Instruction* az = ir_builder.AddNaryExtendedInstruction(
752       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});
753 
754   // Find which values are negative.  Used in later computations.
755   Instruction* is_z_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
756                                                  z->result_id(), f0_const_id);
757   Instruction* is_y_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
758                                                  y->result_id(), f0_const_id);
759   Instruction* is_x_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
760                                                  x->result_id(), f0_const_id);
761 
762   // Find the max value.
763   Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
764       float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
765       {ax->result_id(), ay->result_id()});
766   Instruction* is_z_max =
767       ir_builder.AddBinaryOp(bool_id, SpvOpFOrdGreaterThanEqual,
768                              az->result_id(), amax_x_y->result_id());
769   Instruction* y_gr_x = ir_builder.AddBinaryOp(
770       bool_id, SpvOpFOrdGreaterThanEqual, ay->result_id(), ax->result_id());
771 
772   // Get the value for each case.
773   Instruction* case_z = ir_builder.AddSelect(
774       float_type_id, is_z_neg->result_id(), f5_const_id, f4_const_id);
775   Instruction* case_y = ir_builder.AddSelect(
776       float_type_id, is_y_neg->result_id(), f3_const_id, f2_const_id);
777   Instruction* case_x = ir_builder.AddSelect(
778       float_type_id, is_x_neg->result_id(), f1_const_id, f0_const_id);
779 
780   // Select the correct case.
781   Instruction* sel =
782       ir_builder.AddSelect(float_type_id, y_gr_x->result_id(),
783                            case_y->result_id(), case_x->result_id());
784 
785   // Get the final result by adding 0.5 to |div|.
786   inst->SetOpcode(SpvOpSelect);
787   Instruction::OperandList new_operands;
788   new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_z_max->result_id()}});
789   new_operands.push_back({SPV_OPERAND_TYPE_ID, {case_z->result_id()}});
790   new_operands.push_back({SPV_OPERAND_TYPE_ID, {sel->result_id()}});
791 
792   inst->SetInOperands(std::move(new_operands));
793   ctx->UpdateDefUse(inst);
794   return true;
795 }
796 
797 // A folding rule that will replace the TimeAMD extended instruction in the
798 // SPV_AMD_gcn_shader_ballot.  It returns true if the folding is successful.
799 // It returns False, otherwise.
800 //
801 // The instruction
802 //
803 //  %result = OpExtInst %uint64 %1 TimeAMD
804 //
805 // with
806 //
807 //  %result = OpReadClockKHR %uint64 %uint_3
808 //
809 // NOTE: TimeAMD uses subgroup scope (it is not a real time clock).
ReplaceTimeAMD(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)810 bool ReplaceTimeAMD(IRContext* ctx, Instruction* inst,
811                     const std::vector<const analysis::Constant*>&) {
812   InstructionBuilder ir_builder(
813       ctx, inst,
814       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
815   ctx->AddExtension("SPV_KHR_shader_clock");
816   ctx->AddCapability(SpvCapabilityShaderClockKHR);
817 
818   inst->SetOpcode(SpvOpReadClockKHR);
819   Instruction::OperandList args;
820   uint32_t subgroup_scope_id = ir_builder.GetUintConstantId(SpvScopeSubgroup);
821   args.push_back({SPV_OPERAND_TYPE_ID, {subgroup_scope_id}});
822   inst->SetInOperands(std::move(args));
823   ctx->UpdateDefUse(inst);
824 
825   return true;
826 }
827 
828 class AmdExtFoldingRules : public FoldingRules {
829  public:
AmdExtFoldingRules(IRContext * ctx)830   explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}
831 
832  protected:
AddFoldingRules()833   virtual void AddFoldingRules() override {
834     rules_[SpvOpGroupIAddNonUniformAMD].push_back(
835         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformIAdd>);
836     rules_[SpvOpGroupFAddNonUniformAMD].push_back(
837         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFAdd>);
838     rules_[SpvOpGroupUMinNonUniformAMD].push_back(
839         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMin>);
840     rules_[SpvOpGroupSMinNonUniformAMD].push_back(
841         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMin>);
842     rules_[SpvOpGroupFMinNonUniformAMD].push_back(
843         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMin>);
844     rules_[SpvOpGroupUMaxNonUniformAMD].push_back(
845         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMax>);
846     rules_[SpvOpGroupSMaxNonUniformAMD].push_back(
847         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMax>);
848     rules_[SpvOpGroupFMaxNonUniformAMD].push_back(
849         ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMax>);
850 
851     uint32_t extension_id =
852         context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
853 
854     if (extension_id != 0) {
855       ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}]
856           .push_back(ReplaceSwizzleInvocations);
857       ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
858           .push_back(ReplaceSwizzleInvocationsMasked);
859       ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
860           ReplaceWriteInvocation);
861       ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
862           ReplaceMbcnt);
863     }
864 
865     extension_id = context()->module()->GetExtInstImportId(
866         "SPV_AMD_shader_trinary_minmax");
867 
868     if (extension_id != 0) {
869       ext_rules_[{extension_id, FMin3AMD}].push_back(
870           ReplaceTrinaryMinMax<GLSLstd450FMin>);
871       ext_rules_[{extension_id, UMin3AMD}].push_back(
872           ReplaceTrinaryMinMax<GLSLstd450UMin>);
873       ext_rules_[{extension_id, SMin3AMD}].push_back(
874           ReplaceTrinaryMinMax<GLSLstd450SMin>);
875       ext_rules_[{extension_id, FMax3AMD}].push_back(
876           ReplaceTrinaryMinMax<GLSLstd450FMax>);
877       ext_rules_[{extension_id, UMax3AMD}].push_back(
878           ReplaceTrinaryMinMax<GLSLstd450UMax>);
879       ext_rules_[{extension_id, SMax3AMD}].push_back(
880           ReplaceTrinaryMinMax<GLSLstd450SMax>);
881       ext_rules_[{extension_id, FMid3AMD}].push_back(
882           ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>);
883       ext_rules_[{extension_id, UMid3AMD}].push_back(
884           ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>);
885       ext_rules_[{extension_id, SMid3AMD}].push_back(
886           ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>);
887     }
888 
889     extension_id =
890         context()->module()->GetExtInstImportId("SPV_AMD_gcn_shader");
891 
892     if (extension_id != 0) {
893       ext_rules_[{extension_id, CubeFaceCoordAMD}].push_back(
894           ReplaceCubeFaceCoord);
895       ext_rules_[{extension_id, CubeFaceIndexAMD}].push_back(
896           ReplaceCubeFaceIndex);
897       ext_rules_[{extension_id, TimeAMD}].push_back(ReplaceTimeAMD);
898     }
899   }
900 };
901 
902 class AmdExtConstFoldingRules : public ConstantFoldingRules {
903  public:
AmdExtConstFoldingRules(IRContext * ctx)904   AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {}
905 
906  protected:
AddFoldingRules()907   virtual void AddFoldingRules() override {}
908 };
909 
910 }  // namespace
911 
Process()912 Pass::Status AmdExtensionToKhrPass::Process() {
913   bool changed = false;
914 
915   // Traverse the body of the functions to replace instructions that require
916   // the extensions.
917   InstructionFolder folder(
918       context(),
919       std::unique_ptr<AmdExtFoldingRules>(new AmdExtFoldingRules(context())),
920       MakeUnique<AmdExtConstFoldingRules>(context()));
921   for (Function& func : *get_module()) {
922     func.ForEachInst([&changed, &folder](Instruction* inst) {
923       if (folder.FoldInstruction(inst)) {
924         changed = true;
925       }
926     });
927   }
928 
929   // Now that instruction that require the extensions have been removed, we can
930   // remove the extension instructions.
931   std::set<std::string> ext_to_remove = {"SPV_AMD_shader_ballot",
932                                          "SPV_AMD_shader_trinary_minmax",
933                                          "SPV_AMD_gcn_shader"};
934 
935   std::vector<Instruction*> to_be_killed;
936   for (Instruction& inst : context()->module()->extensions()) {
937     if (inst.opcode() == SpvOpExtension) {
938       if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
939         to_be_killed.push_back(&inst);
940       }
941     }
942   }
943 
944   for (Instruction& inst : context()->ext_inst_imports()) {
945     if (inst.opcode() == SpvOpExtInstImport) {
946       if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
947         to_be_killed.push_back(&inst);
948       }
949     }
950   }
951 
952   for (Instruction* inst : to_be_killed) {
953     context()->KillInst(inst);
954     changed = true;
955   }
956 
957   // The replacements that take place use instructions that are missing before
958   // SPIR-V 1.3. If we changed something, we will have to make sure the version
959   // is at least SPIR-V 1.3 to make sure those instruction can be used.
960   if (changed) {
961     uint32_t version = get_module()->version();
962     if (version < 0x00010300 /*1.3*/) {
963       get_module()->set_version(0x00010300);
964     }
965   }
966   return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
967 }
968 
969 }  // namespace opt
970 }  // namespace spvtools
971