• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/amd_ext_to_khr.h"
16 
17 #include <set>
18 #include <string>
19 
20 #include "ir_builder.h"
21 #include "source/opt/ir_context.h"
22 #include "spv-amd-shader-ballot.insts.inc"
23 #include "type_manager.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
// Extended-instruction numbers of the AMD shader ballot instruction set that
// this pass rewrites. The values are consecutive, starting at 1.
enum AmdShaderBallotExtOpcodes {
  AmdShaderBallotSwizzleInvocationsAMD = 1,
  AmdShaderBallotSwizzleInvocationsMaskedAMD,  // 2
  AmdShaderBallotWriteInvocationAMD,           // 3
  AmdShaderBallotMbcntAMD                      // 4
};
35 
// Extended-instruction numbers of the AMD shader trinary min/max instruction
// set: min3, max3 and mid3, each in float, unsigned and signed flavors. The
// values are consecutive, starting at 1.
enum AmdShaderTrinaryMinMaxExtOpCodes {
  FMin3AMD = 1,
  UMin3AMD,
  SMin3AMD,
  FMax3AMD,
  UMax3AMD,
  SMax3AMD,
  FMid3AMD,
  UMid3AMD,
  SMid3AMD
};
47 
// Extended-instruction numbers of the AMD GCN shader instruction set, listed
// in ascending numeric order.
enum AmdGcnShader {
  CubeFaceIndexAMD = 1,
  CubeFaceCoordAMD = 2,
  TimeAMD = 3
};
49 
GetUIntType(IRContext * ctx)50 analysis::Type* GetUIntType(IRContext* ctx) {
51   analysis::Integer int_type(32, false);
52   return ctx->get_type_mgr()->GetRegisteredType(&int_type);
53 }
54 
55 // Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where
56 // |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450
57 // extended instruction set that corresponds to the trinary instruction being
58 // replaced.
59 template <GLSLstd450 opcode>
ReplaceTrinaryMinMax(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)60 bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst,
61                           const std::vector<const analysis::Constant*>&) {
62   uint32_t glsl405_ext_inst_id =
63       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
64   if (glsl405_ext_inst_id == 0) {
65     ctx->AddExtInstImport("GLSL.std.450");
66     glsl405_ext_inst_id =
67         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
68   }
69 
70   InstructionBuilder ir_builder(
71       ctx, inst,
72       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
73 
74   uint32_t op1 = inst->GetSingleWordInOperand(2);
75   uint32_t op2 = inst->GetSingleWordInOperand(3);
76   uint32_t op3 = inst->GetSingleWordInOperand(4);
77 
78   Instruction* temp = ir_builder.AddNaryExtendedInstruction(
79       inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2});
80 
81   Instruction::OperandList new_operands;
82   new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
83   new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
84                           {static_cast<uint32_t>(opcode)}});
85   new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}});
86   new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}});
87 
88   inst->SetInOperands(std::move(new_operands));
89   ctx->UpdateDefUse(inst);
90   return true;
91 }
92 
93 // Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c),
94 // max(b,c)|. The three parameters are the opcode that correspond to the min,
95 // max, and clamp operations for the type of the instruction being replaced.
96 template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode>
ReplaceTrinaryMid(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)97 bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst,
98                        const std::vector<const analysis::Constant*>&) {
99   uint32_t glsl405_ext_inst_id =
100       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
101   if (glsl405_ext_inst_id == 0) {
102     ctx->AddExtInstImport("GLSL.std.450");
103     glsl405_ext_inst_id =
104         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
105   }
106 
107   InstructionBuilder ir_builder(
108       ctx, inst,
109       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
110 
111   uint32_t op1 = inst->GetSingleWordInOperand(2);
112   uint32_t op2 = inst->GetSingleWordInOperand(3);
113   uint32_t op3 = inst->GetSingleWordInOperand(4);
114 
115   Instruction* min = ir_builder.AddNaryExtendedInstruction(
116       inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode),
117       {op2, op3});
118   Instruction* max = ir_builder.AddNaryExtendedInstruction(
119       inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode),
120       {op2, op3});
121 
122   Instruction::OperandList new_operands;
123   new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
124   new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
125                           {static_cast<uint32_t>(clamp_opcode)}});
126   new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}});
127   new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}});
128   new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}});
129 
130   inst->SetInOperands(std::move(new_operands));
131   ctx->UpdateDefUse(inst);
132   return true;
133 }
134 
135 // Returns a folding rule that will replace the opcode with |opcode| and add
136 // the capabilities required.  The folding rule assumes it is folding an
137 // OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
138 template <spv::Op new_opcode>
ReplaceGroupNonuniformOperationOpCode(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)139 bool ReplaceGroupNonuniformOperationOpCode(
140     IRContext* ctx, Instruction* inst,
141     const std::vector<const analysis::Constant*>&) {
142   switch (new_opcode) {
143     case spv::Op::OpGroupNonUniformIAdd:
144     case spv::Op::OpGroupNonUniformFAdd:
145     case spv::Op::OpGroupNonUniformUMin:
146     case spv::Op::OpGroupNonUniformSMin:
147     case spv::Op::OpGroupNonUniformFMin:
148     case spv::Op::OpGroupNonUniformUMax:
149     case spv::Op::OpGroupNonUniformSMax:
150     case spv::Op::OpGroupNonUniformFMax:
151       break;
152     default:
153       assert(
154           false &&
155           "Should be replacing with a group non uniform arithmetic operation.");
156   }
157 
158   switch (inst->opcode()) {
159     case spv::Op::OpGroupIAddNonUniformAMD:
160     case spv::Op::OpGroupFAddNonUniformAMD:
161     case spv::Op::OpGroupUMinNonUniformAMD:
162     case spv::Op::OpGroupSMinNonUniformAMD:
163     case spv::Op::OpGroupFMinNonUniformAMD:
164     case spv::Op::OpGroupUMaxNonUniformAMD:
165     case spv::Op::OpGroupSMaxNonUniformAMD:
166     case spv::Op::OpGroupFMaxNonUniformAMD:
167       break;
168     default:
169       assert(false &&
170              "Should be replacing a group non uniform arithmetic operation.");
171   }
172 
173   ctx->AddCapability(spv::Capability::GroupNonUniformArithmetic);
174   inst->SetOpcode(new_opcode);
175   return true;
176 }
177 
178 // Returns a folding rule that will replace the SwizzleInvocationsAMD extended
179 // instruction in the SPV_AMD_shader_ballot extension.
180 //
181 // The instruction
182 //
183 //  %offset = OpConstantComposite %v3uint %x %y %z %w
184 //  %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
185 //
186 // is replaced with
187 //
188 // potentially new constants and types
189 //
190 // clang-format off
191 //         %uint_max = OpConstant %uint 0xFFFFFFFF
192 //           %v4uint = OpTypeVector %uint 4
193 //     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
194 //             %null = OpConstantNull %type
195 // clang-format on
196 //
197 // and the following code in the function body
198 //
199 // clang-format off
200 //         %id = OpLoad %uint %SubgroupLocalInvocationId
201 //   %quad_idx = OpBitwiseAnd %uint %id %uint_3
202 //   %quad_ldr = OpBitwiseXor %uint %id %quad_idx
203 //  %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
204 // %target_inv = OpIAdd %uint %quad_ldr %my_offset
205 //  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
206 //    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
207 //     %result = OpSelect %type %is_active %shuffle %null
208 // clang-format on
209 //
210 // Also adding the capabilities and builtins that are needed.
ReplaceSwizzleInvocations(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)211 bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
212                                const std::vector<const analysis::Constant*>&) {
213   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
214   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
215 
216   ctx->AddExtension("SPV_KHR_shader_ballot");
217   ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
218   ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);
219 
220   InstructionBuilder ir_builder(
221       ctx, inst,
222       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
223 
224   uint32_t data_id = inst->GetSingleWordInOperand(2);
225   uint32_t offset_id = inst->GetSingleWordInOperand(3);
226 
227   // Get the subgroup invocation id.
228   uint32_t var_id = ctx->GetBuiltinInputVarId(
229       uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
230   assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
231   Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
232   Instruction* var_ptr_type =
233       ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
234   uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
235 
236   Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
237 
238   uint32_t quad_mask = ir_builder.GetUintConstantId(3);
239 
240   // This gives the offset in the group of 4 of this invocation.
241   Instruction* quad_idx = ir_builder.AddBinaryOp(
242       uint_type_id, spv::Op::OpBitwiseAnd, id->result_id(), quad_mask);
243 
244   // Get the invocation id of the first invocation in the group of 4.
245   Instruction* quad_ldr =
246       ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseXor,
247                              id->result_id(), quad_idx->result_id());
248 
249   // Get the offset of the target invocation from the offset vector.
250   Instruction* my_offset =
251       ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpVectorExtractDynamic,
252                              offset_id, quad_idx->result_id());
253 
254   // Determine the index of the invocation to read from.
255   Instruction* target_inv =
256       ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpIAdd,
257                              quad_ldr->result_id(), my_offset->result_id());
258 
259   // Do the group operations
260   uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
261   uint32_t subgroup_scope =
262       ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
263   const auto* ballot_value_const = const_mgr->GetConstant(
264       type_mgr->GetUIntVectorType(4),
265       {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
266   Instruction* ballot_value =
267       const_mgr->GetDefiningInstruction(ballot_value_const);
268   Instruction* is_active = ir_builder.AddNaryOp(
269       type_mgr->GetBoolTypeId(), spv::Op::OpGroupNonUniformBallotBitExtract,
270       {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
271   Instruction* shuffle =
272       ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
273                            {subgroup_scope, data_id, target_inv->result_id()});
274 
275   // Create the null constant to use in the select.
276   const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
277                                             std::vector<uint32_t>());
278   Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
279 
280   // Build the select.
281   inst->SetOpcode(spv::Op::OpSelect);
282   Instruction::OperandList new_operands;
283   new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
284   new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
285   new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
286 
287   inst->SetInOperands(std::move(new_operands));
288   ctx->UpdateDefUse(inst);
289   return true;
290 }
291 
292 // Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
293 // extended instruction in the SPV_AMD_shader_ballot extension.
294 //
295 // The instruction
296 //
297 //    %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
298 //  %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask
299 //
300 // is replaced with
301 //
302 // potentially new constants and types
303 //
304 // clang-format off
305 // %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
306 //         %uint_max = OpConstant %uint 0xFFFFFFFF
307 //           %v4uint = OpTypeVector %uint 4
308 //     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
309 // clang-format on
310 //
311 // and the following code in the function body
312 //
313 // clang-format off
314 //         %id = OpLoad %uint %SubgroupLocalInvocationId
315 //   %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
316 //        %and = OpBitwiseAnd %uint %id %and_mask
317 //         %or = OpBitwiseOr %uint %and %uint_y
318 // %target_inv = OpBitwiseXor %uint %or %uint_z
319 //  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
320 //    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
321 //     %result = OpSelect %type %is_active %shuffle %uint_0
322 // clang-format on
323 //
324 // Also adding the capabilities and builtins that are needed.
ReplaceSwizzleInvocationsMasked(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)325 bool ReplaceSwizzleInvocationsMasked(
326     IRContext* ctx, Instruction* inst,
327     const std::vector<const analysis::Constant*>&) {
328   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
329   analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
330   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
331 
332   ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
333   ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);
334 
335   InstructionBuilder ir_builder(
336       ctx, inst,
337       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
338 
339   // Get the operands to inst, and the components of the mask
340   uint32_t data_id = inst->GetSingleWordInOperand(2);
341 
342   Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
343   assert(mask_inst->opcode() == spv::Op::OpConstantComposite &&
344          "The mask is suppose to be a vector constant.");
345   assert(mask_inst->NumInOperands() == 3 &&
346          "The mask is suppose to have 3 components.");
347 
348   uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
349   uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
350   uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
351 
352   // Get the subgroup invocation id.
353   uint32_t var_id = ctx->GetBuiltinInputVarId(
354       uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
355   ctx->AddExtension("SPV_KHR_shader_ballot");
356   assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
357   Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
358   Instruction* var_ptr_type =
359       ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
360   uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
361 
362   Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
363 
364   // Do the bitwise operations.
365   uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
366   Instruction* and_mask = ir_builder.AddBinaryOp(
367       uint_type_id, spv::Op::OpBitwiseOr, uint_x, mask_extended);
368   Instruction* and_result =
369       ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseAnd,
370                              id->result_id(), and_mask->result_id());
371   Instruction* or_result = ir_builder.AddBinaryOp(
372       uint_type_id, spv::Op::OpBitwiseOr, and_result->result_id(), uint_y);
373   Instruction* target_inv = ir_builder.AddBinaryOp(
374       uint_type_id, spv::Op::OpBitwiseXor, or_result->result_id(), uint_z);
375 
376   // Do the group operations
377   uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
378   uint32_t subgroup_scope =
379       ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
380   const auto* ballot_value_const = const_mgr->GetConstant(
381       type_mgr->GetUIntVectorType(4),
382       {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
383   Instruction* ballot_value =
384       const_mgr->GetDefiningInstruction(ballot_value_const);
385   Instruction* is_active = ir_builder.AddNaryOp(
386       type_mgr->GetBoolTypeId(), spv::Op::OpGroupNonUniformBallotBitExtract,
387       {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
388   Instruction* shuffle =
389       ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
390                            {subgroup_scope, data_id, target_inv->result_id()});
391 
392   // Create the null constant to use in the select.
393   const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
394                                             std::vector<uint32_t>());
395   Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
396 
397   // Build the select.
398   inst->SetOpcode(spv::Op::OpSelect);
399   Instruction::OperandList new_operands;
400   new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
401   new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
402   new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
403 
404   inst->SetInOperands(std::move(new_operands));
405   ctx->UpdateDefUse(inst);
406   return true;
407 }
408 
409 // Returns a folding rule that will replace the WriteInvocationAMD extended
410 // instruction in the SPV_AMD_shader_ballot extension.
411 //
412 // The instruction
413 //
414 // clang-format off
415 //    %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index
416 // clang-format on
417 //
418 // with
419 //
420 //     %id = OpLoad %uint %SubgroupLocalInvocationId
421 //    %cmp = OpIEqual %bool %id %invocation_index
422 // %result = OpSelect %type %cmp %write_value %input_value
423 //
424 // Also adding the capabilities and builtins that are needed.
ReplaceWriteInvocation(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)425 bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
426                             const std::vector<const analysis::Constant*>&) {
427   uint32_t var_id = ctx->GetBuiltinInputVarId(
428       uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
429   ctx->AddCapability(spv::Capability::SubgroupBallotKHR);
430   ctx->AddExtension("SPV_KHR_shader_ballot");
431   assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
432   Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
433   Instruction* var_ptr_type =
434       ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
435 
436   InstructionBuilder ir_builder(
437       ctx, inst,
438       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
439   Instruction* t =
440       ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
441   analysis::Bool bool_type;
442   uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
443   Instruction* cmp =
444       ir_builder.AddBinaryOp(bool_type_id, spv::Op::OpIEqual, t->result_id(),
445                              inst->GetSingleWordInOperand(4));
446 
447   // Build a select.
448   inst->SetOpcode(spv::Op::OpSelect);
449   Instruction::OperandList new_operands;
450   new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
451   new_operands.push_back(inst->GetInOperand(3));
452   new_operands.push_back(inst->GetInOperand(2));
453 
454   inst->SetInOperands(std::move(new_operands));
455   ctx->UpdateDefUse(inst);
456   return true;
457 }
458 
459 // Returns a folding rule that will replace the MbcntAMD extended instruction in
460 // the SPV_AMD_shader_ballot extension.
461 //
462 // The instruction
463 //
464 //  %result = OpExtInst %uint %1 MbcntAMD %mask
465 //
466 // with
467 //
468 // Get SubgroupLtMask and convert the first 64-bits into a uint64_t because
469 // AMD's shader compiler expects a 64-bit integer mask.
470 //
471 //     %var = OpLoad %v4uint %SubgroupLtMaskKHR
472 // %shuffle = OpVectorShuffle %v2uint %var %var 0 1
473 //    %cast = OpBitcast %ulong %shuffle
474 //
475 // Perform the mask and count the bits.
476 //
477 //     %and = OpBitwiseAnd %ulong %cast %mask
478 //  %result = OpBitCount %uint %and
479 //
480 // Also adding the capabilities and builtins that are needed.
ReplaceMbcnt(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)481 bool ReplaceMbcnt(IRContext* context, Instruction* inst,
482                   const std::vector<const analysis::Constant*>&) {
483   analysis::TypeManager* type_mgr = context->get_type_mgr();
484   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
485 
486   uint32_t var_id =
487       context->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::SubgroupLtMask));
488   assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
489   context->AddCapability(spv::Capability::GroupNonUniformBallot);
490   Instruction* var_inst = def_use_mgr->GetDef(var_id);
491   Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
492   Instruction* var_type =
493       def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
494   assert(var_type->opcode() == spv::Op::OpTypeVector &&
495          "Variable is suppose to be a vector of 4 ints");
496 
497   // Get the type for the shuffle.
498   analysis::Vector temp_type(GetUIntType(context), 2);
499   const analysis::Type* shuffle_type =
500       context->get_type_mgr()->GetRegisteredType(&temp_type);
501   uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
502 
503   uint32_t mask_id = inst->GetSingleWordInOperand(2);
504   Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
505 
506   // Testing with amd's shader compiler shows that a 64-bit mask is expected.
507   assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
508   assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
509 
510   InstructionBuilder ir_builder(
511       context, inst,
512       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
513   Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
514   Instruction* shuffle = ir_builder.AddVectorShuffle(
515       shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
516   Instruction* bitcast = ir_builder.AddUnaryOp(
517       mask_inst->type_id(), spv::Op::OpBitcast, shuffle->result_id());
518   Instruction* t =
519       ir_builder.AddBinaryOp(mask_inst->type_id(), spv::Op::OpBitwiseAnd,
520                              bitcast->result_id(), mask_id);
521 
522   inst->SetOpcode(spv::Op::OpBitCount);
523   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
524   context->UpdateDefUse(inst);
525   return true;
526 }
527 
528 // A folding rule that will replace the CubeFaceCoordAMD extended
529 // instruction in the SPV_AMD_gcn_shader_ballot.  Returns true if the folding is
530 // successful.
531 //
532 // The instruction
533 //
534 //  %result = OpExtInst %v2float %1 CubeFaceCoordAMD %input
535 //
536 // with
537 //
538 //             %x = OpCompositeExtract %float %input 0
539 //             %y = OpCompositeExtract %float %input 1
540 //             %z = OpCompositeExtract %float %input 2
541 //            %nx = OpFNegate %float %x
542 //            %ny = OpFNegate %float %y
543 //            %nz = OpFNegate %float %z
544 //            %ax = OpExtInst %float %n_1 FAbs %x
545 //            %ay = OpExtInst %float %n_1 FAbs %y
546 //            %az = OpExtInst %float %n_1 FAbs %z
547 //      %amax_x_y = OpExtInst %float %n_1 FMax %ay %ax
548 //          %amax = OpExtInst %float %n_1 FMax %az %amax_x_y
549 //        %cubema = OpFMul %float %float_2 %amax
550 //      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
551 //  %not_is_z_max = OpLogicalNot %bool %is_z_max
552 //        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
553 //      %is_y_max = OpLogicalAnd %bool %not_is_z_max %y_gt_x
554 //      %is_z_neg = OpFOrdLessThan %bool %z %float_0
555 // %cubesc_case_1 = OpSelect %float %is_z_neg %nx %x
556 //      %is_x_neg = OpFOrdLessThan %bool %x %float_0
557 // %cubesc_case_2 = OpSelect %float %is_x_neg %z %nz
558 //           %sel = OpSelect %float %is_y_max %x %cubesc_case_2
559 //        %cubesc = OpSelect %float %is_z_max %cubesc_case_1 %sel
560 //      %is_y_neg = OpFOrdLessThan %bool %y %float_0
561 // %cubetc_case_1 = OpSelect %float %is_y_neg %nz %z
562 //        %cubetc = OpSelect %float %is_y_max %cubetc_case_1 %ny
563 //          %cube = OpCompositeConstruct %v2float %cubesc %cubetc
564 //         %denom = OpCompositeConstruct %v2float %cubema %cubema
565 //           %div = OpFDiv %v2float %cube %denom
566 //        %result = OpFAdd %v2float %div %const
567 //
568 // Also adding the capabilities and builtins that are needed.
ReplaceCubeFaceCoord(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)569 bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
570                           const std::vector<const analysis::Constant*>&) {
571   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
572   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
573 
574   uint32_t float_type_id = type_mgr->GetFloatTypeId();
575   const analysis::Type* v2_float_type = type_mgr->GetFloatVectorType(2);
576   uint32_t v2_float_type_id = type_mgr->GetId(v2_float_type);
577   uint32_t bool_id = type_mgr->GetBoolTypeId();
578 
579   InstructionBuilder ir_builder(
580       ctx, inst,
581       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
582 
583   uint32_t input_id = inst->GetSingleWordInOperand(2);
584   uint32_t glsl405_ext_inst_id =
585       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
586   if (glsl405_ext_inst_id == 0) {
587     ctx->AddExtInstImport("GLSL.std.450");
588     glsl405_ext_inst_id =
589         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
590   }
591 
592   // Get the constants that will be used.
593   uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
594   uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
595   uint32_t f0_5_const_id = const_mgr->GetFloatConstId(0.5);
596   const analysis::Constant* vec_const =
597       const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id});
598   uint32_t vec_const_id =
599       const_mgr->GetDefiningInstruction(vec_const)->result_id();
600 
601   // Extract the input values.
602   Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
603   Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
604   Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});
605 
606   // Negate the input values.
607   Instruction* nx =
608       ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, x->result_id());
609   Instruction* ny =
610       ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, y->result_id());
611   Instruction* nz =
612       ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, z->result_id());
613 
614   // Get the abolsute values of the inputs.
615   Instruction* ax = ir_builder.AddNaryExtendedInstruction(
616       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
617   Instruction* ay = ir_builder.AddNaryExtendedInstruction(
618       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
619   Instruction* az = ir_builder.AddNaryExtendedInstruction(
620       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});
621 
622   // Find which values are negative.  Used in later computations.
623   Instruction* is_z_neg = ir_builder.AddBinaryOp(
624       bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
625   Instruction* is_y_neg = ir_builder.AddBinaryOp(
626       bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
627   Instruction* is_x_neg = ir_builder.AddBinaryOp(
628       bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
629 
630   // Compute cubema
631   Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
632       float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
633       {ax->result_id(), ay->result_id()});
634   Instruction* amax = ir_builder.AddNaryExtendedInstruction(
635       float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
636       {az->result_id(), amax_x_y->result_id()});
637   Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, spv::Op::OpFMul,
638                                                f2_const_id, amax->result_id());
639 
640   // Do the comparisons needed for computing cubesc and cubetc.
641   Instruction* is_z_max =
642       ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
643                              az->result_id(), amax_x_y->result_id());
644   Instruction* not_is_z_max = ir_builder.AddUnaryOp(
645       bool_id, spv::Op::OpLogicalNot, is_z_max->result_id());
646   Instruction* y_gr_x =
647       ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
648                              ay->result_id(), ax->result_id());
649   Instruction* is_y_max =
650       ir_builder.AddBinaryOp(bool_id, spv::Op::OpLogicalAnd,
651                              not_is_z_max->result_id(), y_gr_x->result_id());
652 
653   // Select the correct value for cubesc.
654   Instruction* cubesc_case_1 = ir_builder.AddSelect(
655       float_type_id, is_z_neg->result_id(), nx->result_id(), x->result_id());
656   Instruction* cubesc_case_2 = ir_builder.AddSelect(
657       float_type_id, is_x_neg->result_id(), z->result_id(), nz->result_id());
658   Instruction* sel =
659       ir_builder.AddSelect(float_type_id, is_y_max->result_id(), x->result_id(),
660                            cubesc_case_2->result_id());
661   Instruction* cubesc =
662       ir_builder.AddSelect(float_type_id, is_z_max->result_id(),
663                            cubesc_case_1->result_id(), sel->result_id());
664 
665   // Select the correct value for cubetc.
666   Instruction* cubetc_case_1 = ir_builder.AddSelect(
667       float_type_id, is_y_neg->result_id(), nz->result_id(), z->result_id());
668   Instruction* cubetc =
669       ir_builder.AddSelect(float_type_id, is_y_max->result_id(),
670                            cubetc_case_1->result_id(), ny->result_id());
671 
672   // Do the division
673   Instruction* cube = ir_builder.AddCompositeConstruct(
674       v2_float_type_id, {cubesc->result_id(), cubetc->result_id()});
675   Instruction* denom = ir_builder.AddCompositeConstruct(
676       v2_float_type_id, {cubema->result_id(), cubema->result_id()});
677   Instruction* div = ir_builder.AddBinaryOp(
678       v2_float_type_id, spv::Op::OpFDiv, cube->result_id(), denom->result_id());
679 
680   // Get the final result by adding 0.5 to |div|.
681   inst->SetOpcode(spv::Op::OpFAdd);
682   Instruction::OperandList new_operands;
683   new_operands.push_back({SPV_OPERAND_TYPE_ID, {div->result_id()}});
684   new_operands.push_back({SPV_OPERAND_TYPE_ID, {vec_const_id}});
685 
686   inst->SetInOperands(std::move(new_operands));
687   ctx->UpdateDefUse(inst);
688   return true;
689 }
690 
691 // A folding rule that will replace the CubeFaceIndexAMD extended
692 // instruction in the SPV_AMD_gcn_shader_ballot.  Returns true if the folding
693 // is successful.
694 //
695 // The instruction
696 //
697 //  %result = OpExtInst %float %1 CubeFaceIndexAMD %input
698 //
699 // with
700 //
701 //             %x = OpCompositeExtract %float %input 0
702 //             %y = OpCompositeExtract %float %input 1
703 //             %z = OpCompositeExtract %float %input 2
704 //            %ax = OpExtInst %float %n_1 FAbs %x
705 //            %ay = OpExtInst %float %n_1 FAbs %y
706 //            %az = OpExtInst %float %n_1 FAbs %z
707 //      %is_z_neg = OpFOrdLessThan %bool %z %float_0
708 //      %is_y_neg = OpFOrdLessThan %bool %y %float_0
709 //      %is_x_neg = OpFOrdLessThan %bool %x %float_0
710 //      %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay
711 //      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
712 //        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
713 //        %case_z = OpSelect %float %is_z_neg %float_5 %float4
714 //        %case_y = OpSelect %float %is_y_neg %float_3 %float2
715 //        %case_x = OpSelect %float %is_x_neg %float_1 %float0
716 //           %sel = OpSelect %float %y_gt_x %case_y %case_x
717 //        %result = OpSelect %float %is_z_max %case_z %sel
718 //
719 // Also adding the capabilities and builtins that are needed.
ReplaceCubeFaceIndex(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)720 bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
721                           const std::vector<const analysis::Constant*>&) {
722   analysis::TypeManager* type_mgr = ctx->get_type_mgr();
723   analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
724 
725   uint32_t float_type_id = type_mgr->GetFloatTypeId();
726   uint32_t bool_id = type_mgr->GetBoolTypeId();
727 
728   InstructionBuilder ir_builder(
729       ctx, inst,
730       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
731 
732   uint32_t input_id = inst->GetSingleWordInOperand(2);
733   uint32_t glsl405_ext_inst_id =
734       ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
735   if (glsl405_ext_inst_id == 0) {
736     ctx->AddExtInstImport("GLSL.std.450");
737     glsl405_ext_inst_id =
738         ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
739   }
740 
741   // Get the constants that will be used.
742   uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
743   uint32_t f1_const_id = const_mgr->GetFloatConstId(1.0);
744   uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
745   uint32_t f3_const_id = const_mgr->GetFloatConstId(3.0);
746   uint32_t f4_const_id = const_mgr->GetFloatConstId(4.0);
747   uint32_t f5_const_id = const_mgr->GetFloatConstId(5.0);
748 
749   // Extract the input values.
750   Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
751   Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
752   Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});
753 
754   // Get the absolute values of the inputs.
755   Instruction* ax = ir_builder.AddNaryExtendedInstruction(
756       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
757   Instruction* ay = ir_builder.AddNaryExtendedInstruction(
758       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
759   Instruction* az = ir_builder.AddNaryExtendedInstruction(
760       float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});
761 
762   // Find which values are negative.  Used in later computations.
763   Instruction* is_z_neg = ir_builder.AddBinaryOp(
764       bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
765   Instruction* is_y_neg = ir_builder.AddBinaryOp(
766       bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
767   Instruction* is_x_neg = ir_builder.AddBinaryOp(
768       bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
769 
770   // Find the max value.
771   Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
772       float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
773       {ax->result_id(), ay->result_id()});
774   Instruction* is_z_max =
775       ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
776                              az->result_id(), amax_x_y->result_id());
777   Instruction* y_gr_x =
778       ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
779                              ay->result_id(), ax->result_id());
780 
781   // Get the value for each case.
782   Instruction* case_z = ir_builder.AddSelect(
783       float_type_id, is_z_neg->result_id(), f5_const_id, f4_const_id);
784   Instruction* case_y = ir_builder.AddSelect(
785       float_type_id, is_y_neg->result_id(), f3_const_id, f2_const_id);
786   Instruction* case_x = ir_builder.AddSelect(
787       float_type_id, is_x_neg->result_id(), f1_const_id, f0_const_id);
788 
789   // Select the correct case.
790   Instruction* sel =
791       ir_builder.AddSelect(float_type_id, y_gr_x->result_id(),
792                            case_y->result_id(), case_x->result_id());
793 
794   // Get the final result by adding 0.5 to |div|.
795   inst->SetOpcode(spv::Op::OpSelect);
796   Instruction::OperandList new_operands;
797   new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_z_max->result_id()}});
798   new_operands.push_back({SPV_OPERAND_TYPE_ID, {case_z->result_id()}});
799   new_operands.push_back({SPV_OPERAND_TYPE_ID, {sel->result_id()}});
800 
801   inst->SetInOperands(std::move(new_operands));
802   ctx->UpdateDefUse(inst);
803   return true;
804 }
805 
806 // A folding rule that will replace the TimeAMD extended instruction in the
807 // SPV_AMD_gcn_shader_ballot.  It returns true if the folding is successful.
808 // It returns False, otherwise.
809 //
810 // The instruction
811 //
812 //  %result = OpExtInst %uint64 %1 TimeAMD
813 //
814 // with
815 //
816 //  %result = OpReadClockKHR %uint64 %uint_3
817 //
818 // NOTE: TimeAMD uses subgroup scope (it is not a real time clock).
ReplaceTimeAMD(IRContext * ctx,Instruction * inst,const std::vector<const analysis::Constant * > &)819 bool ReplaceTimeAMD(IRContext* ctx, Instruction* inst,
820                     const std::vector<const analysis::Constant*>&) {
821   InstructionBuilder ir_builder(
822       ctx, inst,
823       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
824   ctx->AddExtension("SPV_KHR_shader_clock");
825   ctx->AddCapability(spv::Capability::ShaderClockKHR);
826 
827   inst->SetOpcode(spv::Op::OpReadClockKHR);
828   Instruction::OperandList args;
829   uint32_t subgroup_scope_id =
830       ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
831   args.push_back({SPV_OPERAND_TYPE_ID, {subgroup_scope_id}});
832   inst->SetInOperands(std::move(args));
833   ctx->UpdateDefUse(inst);
834 
835   return true;
836 }
837 
838 class AmdExtFoldingRules : public FoldingRules {
839  public:
AmdExtFoldingRules(IRContext * ctx)840   explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}
841 
842  protected:
AddFoldingRules()843   virtual void AddFoldingRules() override {
844     rules_[spv::Op::OpGroupIAddNonUniformAMD].push_back(
845         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformIAdd>);
846     rules_[spv::Op::OpGroupFAddNonUniformAMD].push_back(
847         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFAdd>);
848     rules_[spv::Op::OpGroupUMinNonUniformAMD].push_back(
849         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMin>);
850     rules_[spv::Op::OpGroupSMinNonUniformAMD].push_back(
851         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMin>);
852     rules_[spv::Op::OpGroupFMinNonUniformAMD].push_back(
853         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMin>);
854     rules_[spv::Op::OpGroupUMaxNonUniformAMD].push_back(
855         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMax>);
856     rules_[spv::Op::OpGroupSMaxNonUniformAMD].push_back(
857         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMax>);
858     rules_[spv::Op::OpGroupFMaxNonUniformAMD].push_back(
859         ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMax>);
860 
861     uint32_t extension_id =
862         context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
863 
864     if (extension_id != 0) {
865       ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}]
866           .push_back(ReplaceSwizzleInvocations);
867       ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
868           .push_back(ReplaceSwizzleInvocationsMasked);
869       ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
870           ReplaceWriteInvocation);
871       ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
872           ReplaceMbcnt);
873     }
874 
875     extension_id = context()->module()->GetExtInstImportId(
876         "SPV_AMD_shader_trinary_minmax");
877 
878     if (extension_id != 0) {
879       ext_rules_[{extension_id, FMin3AMD}].push_back(
880           ReplaceTrinaryMinMax<GLSLstd450FMin>);
881       ext_rules_[{extension_id, UMin3AMD}].push_back(
882           ReplaceTrinaryMinMax<GLSLstd450UMin>);
883       ext_rules_[{extension_id, SMin3AMD}].push_back(
884           ReplaceTrinaryMinMax<GLSLstd450SMin>);
885       ext_rules_[{extension_id, FMax3AMD}].push_back(
886           ReplaceTrinaryMinMax<GLSLstd450FMax>);
887       ext_rules_[{extension_id, UMax3AMD}].push_back(
888           ReplaceTrinaryMinMax<GLSLstd450UMax>);
889       ext_rules_[{extension_id, SMax3AMD}].push_back(
890           ReplaceTrinaryMinMax<GLSLstd450SMax>);
891       ext_rules_[{extension_id, FMid3AMD}].push_back(
892           ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>);
893       ext_rules_[{extension_id, UMid3AMD}].push_back(
894           ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>);
895       ext_rules_[{extension_id, SMid3AMD}].push_back(
896           ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>);
897     }
898 
899     extension_id =
900         context()->module()->GetExtInstImportId("SPV_AMD_gcn_shader");
901 
902     if (extension_id != 0) {
903       ext_rules_[{extension_id, CubeFaceCoordAMD}].push_back(
904           ReplaceCubeFaceCoord);
905       ext_rules_[{extension_id, CubeFaceIndexAMD}].push_back(
906           ReplaceCubeFaceIndex);
907       ext_rules_[{extension_id, TimeAMD}].push_back(ReplaceTimeAMD);
908     }
909   }
910 };
911 
912 class AmdExtConstFoldingRules : public ConstantFoldingRules {
913  public:
AmdExtConstFoldingRules(IRContext * ctx)914   AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {}
915 
916  protected:
AddFoldingRules()917   virtual void AddFoldingRules() override {}
918 };
919 
920 }  // namespace
921 
Process()922 Pass::Status AmdExtensionToKhrPass::Process() {
923   bool changed = false;
924 
925   // Traverse the body of the functions to replace instructions that require
926   // the extensions.
927   InstructionFolder folder(
928       context(),
929       std::unique_ptr<AmdExtFoldingRules>(new AmdExtFoldingRules(context())),
930       MakeUnique<AmdExtConstFoldingRules>(context()));
931   for (Function& func : *get_module()) {
932     func.ForEachInst([&changed, &folder](Instruction* inst) {
933       if (folder.FoldInstruction(inst)) {
934         changed = true;
935       }
936     });
937   }
938 
939   // Now that instruction that require the extensions have been removed, we can
940   // remove the extension instructions.
941   std::set<std::string> ext_to_remove = {"SPV_AMD_shader_ballot",
942                                          "SPV_AMD_shader_trinary_minmax",
943                                          "SPV_AMD_gcn_shader"};
944 
945   std::vector<Instruction*> to_be_killed;
946   for (Instruction& inst : context()->module()->extensions()) {
947     if (inst.opcode() == spv::Op::OpExtension) {
948       if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
949         to_be_killed.push_back(&inst);
950       }
951     }
952   }
953 
954   for (Instruction& inst : context()->ext_inst_imports()) {
955     if (inst.opcode() == spv::Op::OpExtInstImport) {
956       if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
957         to_be_killed.push_back(&inst);
958       }
959     }
960   }
961 
962   for (Instruction* inst : to_be_killed) {
963     context()->KillInst(inst);
964     changed = true;
965   }
966 
967   // The replacements that take place use instructions that are missing before
968   // SPIR-V 1.3. If we changed something, we will have to make sure the version
969   // is at least SPIR-V 1.3 to make sure those instruction can be used.
970   if (changed) {
971     uint32_t version = get_module()->version();
972     if (version < 0x00010300 /*1.3*/) {
973       get_module()->set_version(0x00010300);
974     }
975   }
976   return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
977 }
978 
979 }  // namespace opt
980 }  // namespace spvtools
981