• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 The Khronos Group Inc.
2 // Copyright (c) 2019 Valve Corporation
3 // Copyright (c) 2019 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "convert_to_half_pass.h"
18 
19 #include "source/opt/ir_builder.h"
20 
namespace {

// In-operand index of the Dref operand of the OpImage*Dref* instructions that
// ProcessImageRef rewrites. Declared unsigned because it is only ever used as
// an operand index (GetSingleWordInOperand / SetInOperand); `static` would be
// redundant inside an anonymous namespace.
constexpr uint32_t kImageSampleDrefIdInIdx = 2;

}  // anonymous namespace
27 
28 namespace spvtools {
29 namespace opt {
30 
IsArithmetic(Instruction * inst)31 bool ConvertToHalfPass::IsArithmetic(Instruction* inst) {
32   return target_ops_core_.count(inst->opcode()) != 0 ||
33          (inst->opcode() == SpvOpExtInst &&
34           inst->GetSingleWordInOperand(0) ==
35               context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
36           target_ops_450_.count(inst->GetSingleWordInOperand(1)) != 0);
37 }
38 
IsFloat(Instruction * inst,uint32_t width)39 bool ConvertToHalfPass::IsFloat(Instruction* inst, uint32_t width) {
40   uint32_t ty_id = inst->type_id();
41   if (ty_id == 0) return false;
42   return Pass::IsFloat(ty_id, width);
43 }
44 
IsDecoratedRelaxed(Instruction * inst)45 bool ConvertToHalfPass::IsDecoratedRelaxed(Instruction* inst) {
46   uint32_t r_id = inst->result_id();
47   for (auto r_inst : get_decoration_mgr()->GetDecorationsFor(r_id, false))
48     if (r_inst->opcode() == SpvOpDecorate &&
49         r_inst->GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision)
50       return true;
51   return false;
52 }
53 
IsRelaxed(uint32_t id)54 bool ConvertToHalfPass::IsRelaxed(uint32_t id) {
55   return relaxed_ids_set_.count(id) > 0;
56 }
57 
AddRelaxed(uint32_t id)58 void ConvertToHalfPass::AddRelaxed(uint32_t id) { relaxed_ids_set_.insert(id); }
59 
FloatScalarType(uint32_t width)60 analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
61   analysis::Float float_ty(width);
62   return context()->get_type_mgr()->GetRegisteredType(&float_ty);
63 }
64 
FloatVectorType(uint32_t v_len,uint32_t width)65 analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
66                                                    uint32_t width) {
67   analysis::Type* reg_float_ty = FloatScalarType(width);
68   analysis::Vector vec_ty(reg_float_ty, v_len);
69   return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
70 }
71 
FloatMatrixType(uint32_t v_cnt,uint32_t vty_id,uint32_t width)72 analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
73                                                    uint32_t vty_id,
74                                                    uint32_t width) {
75   Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
76   uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
77   analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
78   analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
79   return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
80 }
81 
EquivFloatTypeId(uint32_t ty_id,uint32_t width)82 uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
83   analysis::Type* reg_equiv_ty;
84   Instruction* ty_inst = get_def_use_mgr()->GetDef(ty_id);
85   if (ty_inst->opcode() == SpvOpTypeMatrix)
86     reg_equiv_ty = FloatMatrixType(ty_inst->GetSingleWordInOperand(1),
87                                    ty_inst->GetSingleWordInOperand(0), width);
88   else if (ty_inst->opcode() == SpvOpTypeVector)
89     reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
90   else  // SpvOpTypeFloat
91     reg_equiv_ty = FloatScalarType(width);
92   return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
93 }
94 
GenConvert(uint32_t * val_idp,uint32_t width,Instruction * inst)95 void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
96                                    Instruction* inst) {
97   Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
98   uint32_t ty_id = val_inst->type_id();
99   uint32_t nty_id = EquivFloatTypeId(ty_id, width);
100   if (nty_id == ty_id) return;
101   Instruction* cvt_inst;
102   InstructionBuilder builder(
103       context(), inst,
104       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
105   if (val_inst->opcode() == SpvOpUndef)
106     cvt_inst = builder.AddNullaryOp(nty_id, SpvOpUndef);
107   else
108     cvt_inst = builder.AddUnaryOp(nty_id, SpvOpFConvert, *val_idp);
109   *val_idp = cvt_inst->result_id();
110 }
111 
MatConvertCleanup(Instruction * inst)112 bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
113   if (inst->opcode() != SpvOpFConvert) return false;
114   uint32_t mty_id = inst->type_id();
115   Instruction* mty_inst = get_def_use_mgr()->GetDef(mty_id);
116   if (mty_inst->opcode() != SpvOpTypeMatrix) return false;
117   uint32_t vty_id = mty_inst->GetSingleWordInOperand(0);
118   uint32_t v_cnt = mty_inst->GetSingleWordInOperand(1);
119   Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
120   uint32_t cty_id = vty_inst->GetSingleWordInOperand(0);
121   Instruction* cty_inst = get_def_use_mgr()->GetDef(cty_id);
122   InstructionBuilder builder(
123       context(), inst,
124       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
125   // Convert each component vector, combine them with OpCompositeConstruct
126   // and replace original instruction.
127   uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
128   uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
129   uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
130   std::vector<Operand> opnds = {};
131   for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
132     Instruction* ext_inst = builder.AddIdLiteralOp(
133         orig_vty_id, SpvOpCompositeExtract, orig_mat_id, vidx);
134     Instruction* cvt_inst =
135         builder.AddUnaryOp(vty_id, SpvOpFConvert, ext_inst->result_id());
136     opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
137   }
138   uint32_t mat_id = TakeNextId();
139   std::unique_ptr<Instruction> mat_inst(new Instruction(
140       context(), SpvOpCompositeConstruct, mty_id, mat_id, opnds));
141   (void)builder.AddInstruction(std::move(mat_inst));
142   context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
143   // Turn original instruction into copy so it is valid.
144   inst->SetOpcode(SpvOpCopyObject);
145   inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
146   get_def_use_mgr()->AnalyzeInstUse(inst);
147   return true;
148 }
149 
RemoveRelaxedDecoration(uint32_t id)150 bool ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {
151   return context()->get_decoration_mgr()->RemoveDecorationsFrom(
152       id, [](const Instruction& dec) {
153         if (dec.opcode() == SpvOpDecorate &&
154             dec.GetSingleWordInOperand(1u) == SpvDecorationRelaxedPrecision)
155           return true;
156         else
157           return false;
158       });
159 }
160 
GenHalfArith(Instruction * inst)161 bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
162   bool modified = false;
163   // Convert all float32 based operands to float16 equivalent and change
164   // instruction type to float16 equivalent.
165   inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
166     Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
167     if (!IsFloat(op_inst, 32)) return;
168     GenConvert(idp, 16, inst);
169     modified = true;
170   });
171   if (IsFloat(inst, 32)) {
172     inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
173     converted_ids_.insert(inst->result_id());
174     modified = true;
175   }
176   if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
177   return modified;
178 }
179 
ProcessPhi(Instruction * inst,uint32_t from_width,uint32_t to_width)180 bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
181                                    uint32_t to_width) {
182   // Add converts of any float operands to to_width if they are of from_width.
183   // If converting to 16, change type of phi to float16 equivalent and remember
184   // result id. Converts need to be added to preceding blocks.
185   uint32_t ocnt = 0;
186   uint32_t* prev_idp;
187   bool modified = false;
188   inst->ForEachInId([&ocnt, &prev_idp, &from_width, &to_width, &modified,
189                      this](uint32_t* idp) {
190     if (ocnt % 2 == 0) {
191       prev_idp = idp;
192     } else {
193       Instruction* val_inst = get_def_use_mgr()->GetDef(*prev_idp);
194       if (IsFloat(val_inst, from_width)) {
195         BasicBlock* bp = context()->get_instr_block(*idp);
196         auto insert_before = bp->tail();
197         if (insert_before != bp->begin()) {
198           --insert_before;
199           if (insert_before->opcode() != SpvOpSelectionMerge &&
200               insert_before->opcode() != SpvOpLoopMerge)
201             ++insert_before;
202         }
203         GenConvert(prev_idp, to_width, &*insert_before);
204         modified = true;
205       }
206     }
207     ++ocnt;
208   });
209   if (to_width == 16u) {
210     inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16u));
211     converted_ids_.insert(inst->result_id());
212     modified = true;
213   }
214   if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
215   return modified;
216 }
217 
ProcessConvert(Instruction * inst)218 bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
219   // If float32 and relaxed, change to float16 convert
220   if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
221     inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
222     get_def_use_mgr()->AnalyzeInstUse(inst);
223     converted_ids_.insert(inst->result_id());
224   }
225   // If operand and result types are the same, change FConvert to CopyObject to
226   // keep validator happy; simplification and DCE will clean it up
227   // One way this can happen is if an FConvert generated during this pass
228   // (likely by ProcessPhi) is later encountered here and its operand has been
229   // changed to half.
230   uint32_t val_id = inst->GetSingleWordInOperand(0);
231   Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
232   if (inst->type_id() == val_inst->type_id()) inst->SetOpcode(SpvOpCopyObject);
233   return true;  // modified
234 }
235 
ProcessImageRef(Instruction * inst)236 bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
237   bool modified = false;
238   // If image reference, only need to convert dref args back to float32
239   if (dref_image_ops_.count(inst->opcode()) != 0) {
240     uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
241     if (converted_ids_.count(dref_id) > 0) {
242       GenConvert(&dref_id, 32, inst);
243       inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
244       get_def_use_mgr()->AnalyzeInstUse(inst);
245       modified = true;
246     }
247   }
248   return modified;
249 }
250 
ProcessDefault(Instruction * inst)251 bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
252   // If non-relaxed instruction has changed operands, need to convert
253   // them back to float32
254   if (inst->opcode() == SpvOpPhi) return ProcessPhi(inst, 16u, 32u);
255   bool modified = false;
256   inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
257     if (converted_ids_.count(*idp) == 0) return;
258     uint32_t old_id = *idp;
259     GenConvert(idp, 32, inst);
260     if (*idp != old_id) modified = true;
261   });
262   if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
263   return modified;
264 }
265 
GenHalfInst(Instruction * inst)266 bool ConvertToHalfPass::GenHalfInst(Instruction* inst) {
267   bool modified = false;
268   // Remember id for later deletion of RelaxedPrecision decoration
269   bool inst_relaxed = IsRelaxed(inst->result_id());
270   if (IsArithmetic(inst) && inst_relaxed)
271     modified = GenHalfArith(inst);
272   else if (inst->opcode() == SpvOpPhi && inst_relaxed)
273     modified = ProcessPhi(inst, 32u, 16u);
274   else if (inst->opcode() == SpvOpFConvert)
275     modified = ProcessConvert(inst);
276   else if (image_ops_.count(inst->opcode()) != 0)
277     modified = ProcessImageRef(inst);
278   else
279     modified = ProcessDefault(inst);
280   return modified;
281 }
282 
CloseRelaxInst(Instruction * inst)283 bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) {
284   if (inst->result_id() == 0) return false;
285   if (IsRelaxed(inst->result_id())) return false;
286   if (!IsFloat(inst, 32)) return false;
287   if (IsDecoratedRelaxed(inst)) {
288     AddRelaxed(inst->result_id());
289     return true;
290   }
291   if (closure_ops_.count(inst->opcode()) == 0) return false;
292   // Can relax if all float operands are relaxed
293   bool relax = true;
294   inst->ForEachInId([&relax, this](uint32_t* idp) {
295     Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
296     if (!IsFloat(op_inst, 32)) return;
297     if (!IsRelaxed(*idp)) relax = false;
298   });
299   if (relax) {
300     AddRelaxed(inst->result_id());
301     return true;
302   }
303   // Can relax if all uses are relaxed
304   relax = true;
305   get_def_use_mgr()->ForEachUser(inst, [&relax, this](Instruction* uinst) {
306     if (uinst->result_id() == 0 || !IsFloat(uinst, 32) ||
307         (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id()))) {
308       relax = false;
309       return;
310     }
311   });
312   if (relax) {
313     AddRelaxed(inst->result_id());
314     return true;
315   }
316   return false;
317 }
318 
ProcessFunction(Function * func)319 bool ConvertToHalfPass::ProcessFunction(Function* func) {
320   // Do a closure of Relaxed on composite and phi instructions
321   bool changed = true;
322   while (changed) {
323     changed = false;
324     cfg()->ForEachBlockInReversePostOrder(
325         func->entry().get(), [&changed, this](BasicBlock* bb) {
326           for (auto ii = bb->begin(); ii != bb->end(); ++ii)
327             changed |= CloseRelaxInst(&*ii);
328         });
329   }
330   // Do convert of relaxed instructions to half precision
331   bool modified = false;
332   cfg()->ForEachBlockInReversePostOrder(
333       func->entry().get(), [&modified, this](BasicBlock* bb) {
334         for (auto ii = bb->begin(); ii != bb->end(); ++ii)
335           modified |= GenHalfInst(&*ii);
336       });
337   // Replace invalid converts of matrix into equivalent vector extracts,
338   // converts and finally a composite construct
339   cfg()->ForEachBlockInReversePostOrder(
340       func->entry().get(), [&modified, this](BasicBlock* bb) {
341         for (auto ii = bb->begin(); ii != bb->end(); ++ii)
342           modified |= MatConvertCleanup(&*ii);
343       });
344   return modified;
345 }
346 
ProcessImpl()347 Pass::Status ConvertToHalfPass::ProcessImpl() {
348   Pass::ProcessFunction pfn = [this](Function* fp) {
349     return ProcessFunction(fp);
350   };
351   bool modified = context()->ProcessReachableCallTree(pfn);
352   // If modified, make sure module has Float16 capability
353   if (modified) context()->AddCapability(SpvCapabilityFloat16);
354   // Remove all RelaxedPrecision decorations from instructions and globals
355   for (auto c_id : relaxed_ids_set_) {
356     modified |= RemoveRelaxedDecoration(c_id);
357   }
358   for (auto& val : get_module()->types_values()) {
359     uint32_t v_id = val.result_id();
360     if (v_id != 0) {
361       modified |= RemoveRelaxedDecoration(v_id);
362     }
363   }
364   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
365 }
366 
Process()367 Pass::Status ConvertToHalfPass::Process() {
368   Initialize();
369   return ProcessImpl();
370 }
371 
Initialize()372 void ConvertToHalfPass::Initialize() {
373   target_ops_core_ = {
374       SpvOpVectorExtractDynamic,
375       SpvOpVectorInsertDynamic,
376       SpvOpVectorShuffle,
377       SpvOpCompositeConstruct,
378       SpvOpCompositeInsert,
379       SpvOpCompositeExtract,
380       SpvOpCopyObject,
381       SpvOpTranspose,
382       SpvOpConvertSToF,
383       SpvOpConvertUToF,
384       // SpvOpFConvert,
385       // SpvOpQuantizeToF16,
386       SpvOpFNegate,
387       SpvOpFAdd,
388       SpvOpFSub,
389       SpvOpFMul,
390       SpvOpFDiv,
391       SpvOpFMod,
392       SpvOpVectorTimesScalar,
393       SpvOpMatrixTimesScalar,
394       SpvOpVectorTimesMatrix,
395       SpvOpMatrixTimesVector,
396       SpvOpMatrixTimesMatrix,
397       SpvOpOuterProduct,
398       SpvOpDot,
399       SpvOpSelect,
400       SpvOpFOrdEqual,
401       SpvOpFUnordEqual,
402       SpvOpFOrdNotEqual,
403       SpvOpFUnordNotEqual,
404       SpvOpFOrdLessThan,
405       SpvOpFUnordLessThan,
406       SpvOpFOrdGreaterThan,
407       SpvOpFUnordGreaterThan,
408       SpvOpFOrdLessThanEqual,
409       SpvOpFUnordLessThanEqual,
410       SpvOpFOrdGreaterThanEqual,
411       SpvOpFUnordGreaterThanEqual,
412   };
413   target_ops_450_ = {
414       GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, GLSLstd450FAbs,
415       GLSLstd450FSign, GLSLstd450Floor, GLSLstd450Ceil, GLSLstd450Fract,
416       GLSLstd450Radians, GLSLstd450Degrees, GLSLstd450Sin, GLSLstd450Cos,
417       GLSLstd450Tan, GLSLstd450Asin, GLSLstd450Acos, GLSLstd450Atan,
418       GLSLstd450Sinh, GLSLstd450Cosh, GLSLstd450Tanh, GLSLstd450Asinh,
419       GLSLstd450Acosh, GLSLstd450Atanh, GLSLstd450Atan2, GLSLstd450Pow,
420       GLSLstd450Exp, GLSLstd450Log, GLSLstd450Exp2, GLSLstd450Log2,
421       GLSLstd450Sqrt, GLSLstd450InverseSqrt, GLSLstd450Determinant,
422       GLSLstd450MatrixInverse,
423       // TODO(greg-lunarg): GLSLstd450ModfStruct,
424       GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp, GLSLstd450FMix,
425       GLSLstd450Step, GLSLstd450SmoothStep, GLSLstd450Fma,
426       // TODO(greg-lunarg): GLSLstd450FrexpStruct,
427       GLSLstd450Ldexp, GLSLstd450Length, GLSLstd450Distance, GLSLstd450Cross,
428       GLSLstd450Normalize, GLSLstd450FaceForward, GLSLstd450Reflect,
429       GLSLstd450Refract, GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp};
430   image_ops_ = {SpvOpImageSampleImplicitLod,
431                 SpvOpImageSampleExplicitLod,
432                 SpvOpImageSampleDrefImplicitLod,
433                 SpvOpImageSampleDrefExplicitLod,
434                 SpvOpImageSampleProjImplicitLod,
435                 SpvOpImageSampleProjExplicitLod,
436                 SpvOpImageSampleProjDrefImplicitLod,
437                 SpvOpImageSampleProjDrefExplicitLod,
438                 SpvOpImageFetch,
439                 SpvOpImageGather,
440                 SpvOpImageDrefGather,
441                 SpvOpImageRead,
442                 SpvOpImageSparseSampleImplicitLod,
443                 SpvOpImageSparseSampleExplicitLod,
444                 SpvOpImageSparseSampleDrefImplicitLod,
445                 SpvOpImageSparseSampleDrefExplicitLod,
446                 SpvOpImageSparseSampleProjImplicitLod,
447                 SpvOpImageSparseSampleProjExplicitLod,
448                 SpvOpImageSparseSampleProjDrefImplicitLod,
449                 SpvOpImageSparseSampleProjDrefExplicitLod,
450                 SpvOpImageSparseFetch,
451                 SpvOpImageSparseGather,
452                 SpvOpImageSparseDrefGather,
453                 SpvOpImageSparseTexelsResident,
454                 SpvOpImageSparseRead};
455   dref_image_ops_ = {
456       SpvOpImageSampleDrefImplicitLod,
457       SpvOpImageSampleDrefExplicitLod,
458       SpvOpImageSampleProjDrefImplicitLod,
459       SpvOpImageSampleProjDrefExplicitLod,
460       SpvOpImageDrefGather,
461       SpvOpImageSparseSampleDrefImplicitLod,
462       SpvOpImageSparseSampleDrefExplicitLod,
463       SpvOpImageSparseSampleProjDrefImplicitLod,
464       SpvOpImageSparseSampleProjDrefExplicitLod,
465       SpvOpImageSparseDrefGather,
466   };
467   closure_ops_ = {
468       SpvOpVectorExtractDynamic,
469       SpvOpVectorInsertDynamic,
470       SpvOpVectorShuffle,
471       SpvOpCompositeConstruct,
472       SpvOpCompositeInsert,
473       SpvOpCompositeExtract,
474       SpvOpCopyObject,
475       SpvOpTranspose,
476       SpvOpPhi,
477   };
478   relaxed_ids_set_.clear();
479   converted_ids_.clear();
480 }
481 
482 }  // namespace opt
483 }  // namespace spvtools
484