• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 The Khronos Group Inc.
2 // Copyright (c) 2019 Valve Corporation
3 // Copyright (c) 2019 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "convert_to_half_pass.h"
18 
19 #include "source/opt/ir_builder.h"
20 
namespace {

// In-operand index of the Dref id that ProcessImageRef reads from, and writes
// back to, the image instructions it handles. Declared unsigned to match the
// operand-index parameters it is passed to.
constexpr uint32_t kImageSampleDrefIdInIdx = 2;

}  // anonymous namespace
27 
28 namespace spvtools {
29 namespace opt {
30 
IsArithmetic(Instruction * inst)31 bool ConvertToHalfPass::IsArithmetic(Instruction* inst) {
32   return target_ops_core_.count(inst->opcode()) != 0 ||
33          (inst->opcode() == SpvOpExtInst &&
34           inst->GetSingleWordInOperand(0) ==
35               context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
36           target_ops_450_.count(inst->GetSingleWordInOperand(1)) != 0);
37 }
38 
IsFloat(Instruction * inst,uint32_t width)39 bool ConvertToHalfPass::IsFloat(Instruction* inst, uint32_t width) {
40   uint32_t ty_id = inst->type_id();
41   if (ty_id == 0) return false;
42   return Pass::IsFloat(ty_id, width);
43 }
44 
IsDecoratedRelaxed(Instruction * inst)45 bool ConvertToHalfPass::IsDecoratedRelaxed(Instruction* inst) {
46   uint32_t r_id = inst->result_id();
47   for (auto r_inst : get_decoration_mgr()->GetDecorationsFor(r_id, false))
48     if (r_inst->opcode() == SpvOpDecorate &&
49         r_inst->GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision)
50       return true;
51   return false;
52 }
53 
IsRelaxed(uint32_t id)54 bool ConvertToHalfPass::IsRelaxed(uint32_t id) {
55   return relaxed_ids_set_.count(id) > 0;
56 }
57 
AddRelaxed(uint32_t id)58 void ConvertToHalfPass::AddRelaxed(uint32_t id) { relaxed_ids_set_.insert(id); }
59 
FloatScalarType(uint32_t width)60 analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
61   analysis::Float float_ty(width);
62   return context()->get_type_mgr()->GetRegisteredType(&float_ty);
63 }
64 
FloatVectorType(uint32_t v_len,uint32_t width)65 analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
66                                                    uint32_t width) {
67   analysis::Type* reg_float_ty = FloatScalarType(width);
68   analysis::Vector vec_ty(reg_float_ty, v_len);
69   return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
70 }
71 
FloatMatrixType(uint32_t v_cnt,uint32_t vty_id,uint32_t width)72 analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
73                                                    uint32_t vty_id,
74                                                    uint32_t width) {
75   Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
76   uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
77   analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
78   analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
79   return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
80 }
81 
EquivFloatTypeId(uint32_t ty_id,uint32_t width)82 uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
83   analysis::Type* reg_equiv_ty;
84   Instruction* ty_inst = get_def_use_mgr()->GetDef(ty_id);
85   if (ty_inst->opcode() == SpvOpTypeMatrix)
86     reg_equiv_ty = FloatMatrixType(ty_inst->GetSingleWordInOperand(1),
87                                    ty_inst->GetSingleWordInOperand(0), width);
88   else if (ty_inst->opcode() == SpvOpTypeVector)
89     reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
90   else  // SpvOpTypeFloat
91     reg_equiv_ty = FloatScalarType(width);
92   return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
93 }
94 
GenConvert(uint32_t * val_idp,uint32_t width,Instruction * inst)95 void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
96                                    Instruction* inst) {
97   Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
98   uint32_t ty_id = val_inst->type_id();
99   uint32_t nty_id = EquivFloatTypeId(ty_id, width);
100   if (nty_id == ty_id) return;
101   Instruction* cvt_inst;
102   InstructionBuilder builder(
103       context(), inst,
104       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
105   if (val_inst->opcode() == SpvOpUndef)
106     cvt_inst = builder.AddNullaryOp(nty_id, SpvOpUndef);
107   else
108     cvt_inst = builder.AddUnaryOp(nty_id, SpvOpFConvert, *val_idp);
109   *val_idp = cvt_inst->result_id();
110 }
111 
MatConvertCleanup(Instruction * inst)112 bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
113   if (inst->opcode() != SpvOpFConvert) return false;
114   uint32_t mty_id = inst->type_id();
115   Instruction* mty_inst = get_def_use_mgr()->GetDef(mty_id);
116   if (mty_inst->opcode() != SpvOpTypeMatrix) return false;
117   uint32_t vty_id = mty_inst->GetSingleWordInOperand(0);
118   uint32_t v_cnt = mty_inst->GetSingleWordInOperand(1);
119   Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
120   uint32_t cty_id = vty_inst->GetSingleWordInOperand(0);
121   Instruction* cty_inst = get_def_use_mgr()->GetDef(cty_id);
122   InstructionBuilder builder(
123       context(), inst,
124       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
125   // Convert each component vector, combine them with OpCompositeConstruct
126   // and replace original instruction.
127   uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
128   uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
129   uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
130   std::vector<Operand> opnds = {};
131   for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
132     Instruction* ext_inst = builder.AddIdLiteralOp(
133         orig_vty_id, SpvOpCompositeExtract, orig_mat_id, vidx);
134     Instruction* cvt_inst =
135         builder.AddUnaryOp(vty_id, SpvOpFConvert, ext_inst->result_id());
136     opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
137   }
138   uint32_t mat_id = TakeNextId();
139   std::unique_ptr<Instruction> mat_inst(new Instruction(
140       context(), SpvOpCompositeConstruct, mty_id, mat_id, opnds));
141   (void)builder.AddInstruction(std::move(mat_inst));
142   context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
143   // Turn original instruction into copy so it is valid.
144   inst->SetOpcode(SpvOpCopyObject);
145   inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
146   get_def_use_mgr()->AnalyzeInstUse(inst);
147   return true;
148 }
149 
RemoveRelaxedDecoration(uint32_t id)150 bool ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {
151   return context()->get_decoration_mgr()->RemoveDecorationsFrom(
152       id, [](const Instruction& dec) {
153         if (dec.opcode() == SpvOpDecorate &&
154             dec.GetSingleWordInOperand(1u) == SpvDecorationRelaxedPrecision)
155           return true;
156         else
157           return false;
158       });
159 }
160 
GenHalfArith(Instruction * inst)161 bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
162   bool modified = false;
163   // Convert all float32 based operands to float16 equivalent and change
164   // instruction type to float16 equivalent.
165   inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
166     Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
167     if (!IsFloat(op_inst, 32)) return;
168     GenConvert(idp, 16, inst);
169     modified = true;
170   });
171   if (IsFloat(inst, 32)) {
172     inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
173     converted_ids_.insert(inst->result_id());
174     modified = true;
175   }
176   if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
177   return modified;
178 }
179 
ProcessPhi(Instruction * inst)180 bool ConvertToHalfPass::ProcessPhi(Instruction* inst) {
181   // Add float16 converts of any float32 operands and change type
182   // of phi to float16 equivalent. Operand converts need to be added to
183   // preceeding blocks.
184   uint32_t ocnt = 0;
185   uint32_t* prev_idp;
186   inst->ForEachInId([&ocnt, &prev_idp, this](uint32_t* idp) {
187     if (ocnt % 2 == 0) {
188       prev_idp = idp;
189     } else {
190       Instruction* val_inst = get_def_use_mgr()->GetDef(*prev_idp);
191       if (IsFloat(val_inst, 32)) {
192         BasicBlock* bp = context()->get_instr_block(*idp);
193         auto insert_before = bp->tail();
194         if (insert_before != bp->begin()) {
195           --insert_before;
196           if (insert_before->opcode() != SpvOpSelectionMerge &&
197               insert_before->opcode() != SpvOpLoopMerge)
198             ++insert_before;
199         }
200         GenConvert(prev_idp, 16, &*insert_before);
201       }
202     }
203     ++ocnt;
204   });
205   inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
206   get_def_use_mgr()->AnalyzeInstUse(inst);
207   converted_ids_.insert(inst->result_id());
208   return true;
209 }
210 
ProcessConvert(Instruction * inst)211 bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
212   // If float32 and relaxed, change to float16 convert
213   if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
214     inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
215     get_def_use_mgr()->AnalyzeInstUse(inst);
216     converted_ids_.insert(inst->result_id());
217   }
218   // If operand and result types are the same, change FConvert to CopyObject to
219   // keep validator happy; simplification and DCE will clean it up
220   // One way this can happen is if an FConvert generated during this pass
221   // (likely by ProcessPhi) is later encountered here and its operand has been
222   // changed to half.
223   uint32_t val_id = inst->GetSingleWordInOperand(0);
224   Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
225   if (inst->type_id() == val_inst->type_id()) inst->SetOpcode(SpvOpCopyObject);
226   return true;  // modified
227 }
228 
ProcessImageRef(Instruction * inst)229 bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
230   bool modified = false;
231   // If image reference, only need to convert dref args back to float32
232   if (dref_image_ops_.count(inst->opcode()) != 0) {
233     uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
234     if (converted_ids_.count(dref_id) > 0) {
235       GenConvert(&dref_id, 32, inst);
236       inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
237       get_def_use_mgr()->AnalyzeInstUse(inst);
238       modified = true;
239     }
240   }
241   return modified;
242 }
243 
ProcessDefault(Instruction * inst)244 bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
245   bool modified = false;
246   // If non-relaxed instruction has changed operands, need to convert
247   // them back to float32
248   inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
249     if (converted_ids_.count(*idp) == 0) return;
250     uint32_t old_id = *idp;
251     GenConvert(idp, 32, inst);
252     if (*idp != old_id) modified = true;
253   });
254   if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
255   return modified;
256 }
257 
GenHalfInst(Instruction * inst)258 bool ConvertToHalfPass::GenHalfInst(Instruction* inst) {
259   bool modified = false;
260   // Remember id for later deletion of RelaxedPrecision decoration
261   bool inst_relaxed = IsRelaxed(inst->result_id());
262   if (IsArithmetic(inst) && inst_relaxed)
263     modified = GenHalfArith(inst);
264   else if (inst->opcode() == SpvOpPhi && inst_relaxed)
265     modified = ProcessPhi(inst);
266   else if (inst->opcode() == SpvOpFConvert)
267     modified = ProcessConvert(inst);
268   else if (image_ops_.count(inst->opcode()) != 0)
269     modified = ProcessImageRef(inst);
270   else
271     modified = ProcessDefault(inst);
272   return modified;
273 }
274 
CloseRelaxInst(Instruction * inst)275 bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) {
276   if (inst->result_id() == 0) return false;
277   if (IsRelaxed(inst->result_id())) return false;
278   if (!IsFloat(inst, 32)) return false;
279   if (IsDecoratedRelaxed(inst)) {
280     AddRelaxed(inst->result_id());
281     return true;
282   }
283   if (closure_ops_.count(inst->opcode()) == 0) return false;
284   // Can relax if all float operands are relaxed
285   bool relax = true;
286   inst->ForEachInId([&relax, this](uint32_t* idp) {
287     Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
288     if (!IsFloat(op_inst, 32)) return;
289     if (!IsRelaxed(*idp)) relax = false;
290   });
291   if (relax) {
292     AddRelaxed(inst->result_id());
293     return true;
294   }
295   // Can relax if all uses are relaxed
296   relax = true;
297   get_def_use_mgr()->ForEachUser(inst, [&relax, this](Instruction* uinst) {
298     if (uinst->result_id() == 0 || !IsFloat(uinst, 32) ||
299         (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id()))) {
300       relax = false;
301       return;
302     }
303   });
304   if (relax) {
305     AddRelaxed(inst->result_id());
306     return true;
307   }
308   return false;
309 }
310 
ProcessFunction(Function * func)311 bool ConvertToHalfPass::ProcessFunction(Function* func) {
312   // Do a closure of Relaxed on composite and phi instructions
313   bool changed = true;
314   while (changed) {
315     changed = false;
316     cfg()->ForEachBlockInReversePostOrder(
317         func->entry().get(), [&changed, this](BasicBlock* bb) {
318           for (auto ii = bb->begin(); ii != bb->end(); ++ii)
319             changed |= CloseRelaxInst(&*ii);
320         });
321   }
322   // Do convert of relaxed instructions to half precision
323   bool modified = false;
324   cfg()->ForEachBlockInReversePostOrder(
325       func->entry().get(), [&modified, this](BasicBlock* bb) {
326         for (auto ii = bb->begin(); ii != bb->end(); ++ii)
327           modified |= GenHalfInst(&*ii);
328       });
329   // Replace invalid converts of matrix into equivalent vector extracts,
330   // converts and finally a composite construct
331   cfg()->ForEachBlockInReversePostOrder(
332       func->entry().get(), [&modified, this](BasicBlock* bb) {
333         for (auto ii = bb->begin(); ii != bb->end(); ++ii)
334           modified |= MatConvertCleanup(&*ii);
335       });
336   return modified;
337 }
338 
ProcessImpl()339 Pass::Status ConvertToHalfPass::ProcessImpl() {
340   Pass::ProcessFunction pfn = [this](Function* fp) {
341     return ProcessFunction(fp);
342   };
343   bool modified = context()->ProcessEntryPointCallTree(pfn);
344   // If modified, make sure module has Float16 capability
345   if (modified) context()->AddCapability(SpvCapabilityFloat16);
346   // Remove all RelaxedPrecision decorations from instructions and globals
347   for (auto c_id : relaxed_ids_set_) {
348     modified |= RemoveRelaxedDecoration(c_id);
349   }
350   for (auto& val : get_module()->types_values()) {
351     uint32_t v_id = val.result_id();
352     if (v_id != 0) {
353       modified |= RemoveRelaxedDecoration(v_id);
354     }
355   }
356   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
357 }
358 
Process()359 Pass::Status ConvertToHalfPass::Process() {
360   Initialize();
361   return ProcessImpl();
362 }
363 
Initialize()364 void ConvertToHalfPass::Initialize() {
365   target_ops_core_ = {
366       SpvOpVectorExtractDynamic,
367       SpvOpVectorInsertDynamic,
368       SpvOpVectorShuffle,
369       SpvOpCompositeConstruct,
370       SpvOpCompositeInsert,
371       SpvOpCompositeExtract,
372       SpvOpCopyObject,
373       SpvOpTranspose,
374       SpvOpConvertSToF,
375       SpvOpConvertUToF,
376       // SpvOpFConvert,
377       // SpvOpQuantizeToF16,
378       SpvOpFNegate,
379       SpvOpFAdd,
380       SpvOpFSub,
381       SpvOpFMul,
382       SpvOpFDiv,
383       SpvOpFMod,
384       SpvOpVectorTimesScalar,
385       SpvOpMatrixTimesScalar,
386       SpvOpVectorTimesMatrix,
387       SpvOpMatrixTimesVector,
388       SpvOpMatrixTimesMatrix,
389       SpvOpOuterProduct,
390       SpvOpDot,
391       SpvOpSelect,
392       SpvOpFOrdEqual,
393       SpvOpFUnordEqual,
394       SpvOpFOrdNotEqual,
395       SpvOpFUnordNotEqual,
396       SpvOpFOrdLessThan,
397       SpvOpFUnordLessThan,
398       SpvOpFOrdGreaterThan,
399       SpvOpFUnordGreaterThan,
400       SpvOpFOrdLessThanEqual,
401       SpvOpFUnordLessThanEqual,
402       SpvOpFOrdGreaterThanEqual,
403       SpvOpFUnordGreaterThanEqual,
404   };
405   target_ops_450_ = {
406       GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, GLSLstd450FAbs,
407       GLSLstd450FSign, GLSLstd450Floor, GLSLstd450Ceil, GLSLstd450Fract,
408       GLSLstd450Radians, GLSLstd450Degrees, GLSLstd450Sin, GLSLstd450Cos,
409       GLSLstd450Tan, GLSLstd450Asin, GLSLstd450Acos, GLSLstd450Atan,
410       GLSLstd450Sinh, GLSLstd450Cosh, GLSLstd450Tanh, GLSLstd450Asinh,
411       GLSLstd450Acosh, GLSLstd450Atanh, GLSLstd450Atan2, GLSLstd450Pow,
412       GLSLstd450Exp, GLSLstd450Log, GLSLstd450Exp2, GLSLstd450Log2,
413       GLSLstd450Sqrt, GLSLstd450InverseSqrt, GLSLstd450Determinant,
414       GLSLstd450MatrixInverse,
415       // TODO(greg-lunarg): GLSLstd450ModfStruct,
416       GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp, GLSLstd450FMix,
417       GLSLstd450Step, GLSLstd450SmoothStep, GLSLstd450Fma,
418       // TODO(greg-lunarg): GLSLstd450FrexpStruct,
419       GLSLstd450Ldexp, GLSLstd450Length, GLSLstd450Distance, GLSLstd450Cross,
420       GLSLstd450Normalize, GLSLstd450FaceForward, GLSLstd450Reflect,
421       GLSLstd450Refract, GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp};
422   image_ops_ = {SpvOpImageSampleImplicitLod,
423                 SpvOpImageSampleExplicitLod,
424                 SpvOpImageSampleDrefImplicitLod,
425                 SpvOpImageSampleDrefExplicitLod,
426                 SpvOpImageSampleProjImplicitLod,
427                 SpvOpImageSampleProjExplicitLod,
428                 SpvOpImageSampleProjDrefImplicitLod,
429                 SpvOpImageSampleProjDrefExplicitLod,
430                 SpvOpImageFetch,
431                 SpvOpImageGather,
432                 SpvOpImageDrefGather,
433                 SpvOpImageRead,
434                 SpvOpImageSparseSampleImplicitLod,
435                 SpvOpImageSparseSampleExplicitLod,
436                 SpvOpImageSparseSampleDrefImplicitLod,
437                 SpvOpImageSparseSampleDrefExplicitLod,
438                 SpvOpImageSparseSampleProjImplicitLod,
439                 SpvOpImageSparseSampleProjExplicitLod,
440                 SpvOpImageSparseSampleProjDrefImplicitLod,
441                 SpvOpImageSparseSampleProjDrefExplicitLod,
442                 SpvOpImageSparseFetch,
443                 SpvOpImageSparseGather,
444                 SpvOpImageSparseDrefGather,
445                 SpvOpImageSparseTexelsResident,
446                 SpvOpImageSparseRead};
447   dref_image_ops_ = {
448       SpvOpImageSampleDrefImplicitLod,
449       SpvOpImageSampleDrefExplicitLod,
450       SpvOpImageSampleProjDrefImplicitLod,
451       SpvOpImageSampleProjDrefExplicitLod,
452       SpvOpImageDrefGather,
453       SpvOpImageSparseSampleDrefImplicitLod,
454       SpvOpImageSparseSampleDrefExplicitLod,
455       SpvOpImageSparseSampleProjDrefImplicitLod,
456       SpvOpImageSparseSampleProjDrefExplicitLod,
457       SpvOpImageSparseDrefGather,
458   };
459   closure_ops_ = {
460       SpvOpVectorExtractDynamic,
461       SpvOpVectorInsertDynamic,
462       SpvOpVectorShuffle,
463       SpvOpCompositeConstruct,
464       SpvOpCompositeInsert,
465       SpvOpCompositeExtract,
466       SpvOpCopyObject,
467       SpvOpTranspose,
468       SpvOpPhi,
469   };
470   relaxed_ids_set_.clear();
471   converted_ids_.clear();
472 }
473 
474 }  // namespace opt
475 }  // namespace spvtools
476