1 // Copyright (c) 2019 The Khronos Group Inc.
2 // Copyright (c) 2019 Valve Corporation
3 // Copyright (c) 2019 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "convert_to_half_pass.h"
18
19 #include "source/opt/ir_builder.h"
20
namespace {

// In-operand index of the depth-reference (Dref) id in the SPIR-V image
// instructions that carry one.
constexpr int kImageSampleDrefIdInIdx = 2;

}  // anonymous namespace
27
28 namespace spvtools {
29 namespace opt {
30
IsArithmetic(Instruction * inst)31 bool ConvertToHalfPass::IsArithmetic(Instruction* inst) {
32 return target_ops_core_.count(inst->opcode()) != 0 ||
33 (inst->opcode() == SpvOpExtInst &&
34 inst->GetSingleWordInOperand(0) ==
35 context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
36 target_ops_450_.count(inst->GetSingleWordInOperand(1)) != 0);
37 }
38
IsFloat(Instruction * inst,uint32_t width)39 bool ConvertToHalfPass::IsFloat(Instruction* inst, uint32_t width) {
40 uint32_t ty_id = inst->type_id();
41 if (ty_id == 0) return false;
42 return Pass::IsFloat(ty_id, width);
43 }
44
IsDecoratedRelaxed(Instruction * inst)45 bool ConvertToHalfPass::IsDecoratedRelaxed(Instruction* inst) {
46 uint32_t r_id = inst->result_id();
47 for (auto r_inst : get_decoration_mgr()->GetDecorationsFor(r_id, false))
48 if (r_inst->opcode() == SpvOpDecorate &&
49 r_inst->GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision)
50 return true;
51 return false;
52 }
53
IsRelaxed(uint32_t id)54 bool ConvertToHalfPass::IsRelaxed(uint32_t id) {
55 return relaxed_ids_set_.count(id) > 0;
56 }
57
AddRelaxed(uint32_t id)58 void ConvertToHalfPass::AddRelaxed(uint32_t id) { relaxed_ids_set_.insert(id); }
59
FloatScalarType(uint32_t width)60 analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
61 analysis::Float float_ty(width);
62 return context()->get_type_mgr()->GetRegisteredType(&float_ty);
63 }
64
FloatVectorType(uint32_t v_len,uint32_t width)65 analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
66 uint32_t width) {
67 analysis::Type* reg_float_ty = FloatScalarType(width);
68 analysis::Vector vec_ty(reg_float_ty, v_len);
69 return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
70 }
71
FloatMatrixType(uint32_t v_cnt,uint32_t vty_id,uint32_t width)72 analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
73 uint32_t vty_id,
74 uint32_t width) {
75 Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
76 uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
77 analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
78 analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
79 return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
80 }
81
EquivFloatTypeId(uint32_t ty_id,uint32_t width)82 uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
83 analysis::Type* reg_equiv_ty;
84 Instruction* ty_inst = get_def_use_mgr()->GetDef(ty_id);
85 if (ty_inst->opcode() == SpvOpTypeMatrix)
86 reg_equiv_ty = FloatMatrixType(ty_inst->GetSingleWordInOperand(1),
87 ty_inst->GetSingleWordInOperand(0), width);
88 else if (ty_inst->opcode() == SpvOpTypeVector)
89 reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
90 else // SpvOpTypeFloat
91 reg_equiv_ty = FloatScalarType(width);
92 return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
93 }
94
GenConvert(uint32_t * val_idp,uint32_t width,Instruction * inst)95 void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
96 Instruction* inst) {
97 Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
98 uint32_t ty_id = val_inst->type_id();
99 uint32_t nty_id = EquivFloatTypeId(ty_id, width);
100 if (nty_id == ty_id) return;
101 Instruction* cvt_inst;
102 InstructionBuilder builder(
103 context(), inst,
104 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
105 if (val_inst->opcode() == SpvOpUndef)
106 cvt_inst = builder.AddNullaryOp(nty_id, SpvOpUndef);
107 else
108 cvt_inst = builder.AddUnaryOp(nty_id, SpvOpFConvert, *val_idp);
109 *val_idp = cvt_inst->result_id();
110 }
111
MatConvertCleanup(Instruction * inst)112 bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
113 if (inst->opcode() != SpvOpFConvert) return false;
114 uint32_t mty_id = inst->type_id();
115 Instruction* mty_inst = get_def_use_mgr()->GetDef(mty_id);
116 if (mty_inst->opcode() != SpvOpTypeMatrix) return false;
117 uint32_t vty_id = mty_inst->GetSingleWordInOperand(0);
118 uint32_t v_cnt = mty_inst->GetSingleWordInOperand(1);
119 Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
120 uint32_t cty_id = vty_inst->GetSingleWordInOperand(0);
121 Instruction* cty_inst = get_def_use_mgr()->GetDef(cty_id);
122 InstructionBuilder builder(
123 context(), inst,
124 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
125 // Convert each component vector, combine them with OpCompositeConstruct
126 // and replace original instruction.
127 uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
128 uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
129 uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
130 std::vector<Operand> opnds = {};
131 for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
132 Instruction* ext_inst = builder.AddIdLiteralOp(
133 orig_vty_id, SpvOpCompositeExtract, orig_mat_id, vidx);
134 Instruction* cvt_inst =
135 builder.AddUnaryOp(vty_id, SpvOpFConvert, ext_inst->result_id());
136 opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
137 }
138 uint32_t mat_id = TakeNextId();
139 std::unique_ptr<Instruction> mat_inst(new Instruction(
140 context(), SpvOpCompositeConstruct, mty_id, mat_id, opnds));
141 (void)builder.AddInstruction(std::move(mat_inst));
142 context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
143 // Turn original instruction into copy so it is valid.
144 inst->SetOpcode(SpvOpCopyObject);
145 inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
146 get_def_use_mgr()->AnalyzeInstUse(inst);
147 return true;
148 }
149
RemoveRelaxedDecoration(uint32_t id)150 bool ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {
151 return context()->get_decoration_mgr()->RemoveDecorationsFrom(
152 id, [](const Instruction& dec) {
153 if (dec.opcode() == SpvOpDecorate &&
154 dec.GetSingleWordInOperand(1u) == SpvDecorationRelaxedPrecision)
155 return true;
156 else
157 return false;
158 });
159 }
160
GenHalfArith(Instruction * inst)161 bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
162 bool modified = false;
163 // Convert all float32 based operands to float16 equivalent and change
164 // instruction type to float16 equivalent.
165 inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
166 Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
167 if (!IsFloat(op_inst, 32)) return;
168 GenConvert(idp, 16, inst);
169 modified = true;
170 });
171 if (IsFloat(inst, 32)) {
172 inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
173 converted_ids_.insert(inst->result_id());
174 modified = true;
175 }
176 if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
177 return modified;
178 }
179
ProcessPhi(Instruction * inst)180 bool ConvertToHalfPass::ProcessPhi(Instruction* inst) {
181 // Add float16 converts of any float32 operands and change type
182 // of phi to float16 equivalent. Operand converts need to be added to
183 // preceeding blocks.
184 uint32_t ocnt = 0;
185 uint32_t* prev_idp;
186 inst->ForEachInId([&ocnt, &prev_idp, this](uint32_t* idp) {
187 if (ocnt % 2 == 0) {
188 prev_idp = idp;
189 } else {
190 Instruction* val_inst = get_def_use_mgr()->GetDef(*prev_idp);
191 if (IsFloat(val_inst, 32)) {
192 BasicBlock* bp = context()->get_instr_block(*idp);
193 auto insert_before = bp->tail();
194 if (insert_before != bp->begin()) {
195 --insert_before;
196 if (insert_before->opcode() != SpvOpSelectionMerge &&
197 insert_before->opcode() != SpvOpLoopMerge)
198 ++insert_before;
199 }
200 GenConvert(prev_idp, 16, &*insert_before);
201 }
202 }
203 ++ocnt;
204 });
205 inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
206 get_def_use_mgr()->AnalyzeInstUse(inst);
207 converted_ids_.insert(inst->result_id());
208 return true;
209 }
210
ProcessConvert(Instruction * inst)211 bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
212 // If float32 and relaxed, change to float16 convert
213 if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
214 inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
215 get_def_use_mgr()->AnalyzeInstUse(inst);
216 converted_ids_.insert(inst->result_id());
217 }
218 // If operand and result types are the same, change FConvert to CopyObject to
219 // keep validator happy; simplification and DCE will clean it up
220 // One way this can happen is if an FConvert generated during this pass
221 // (likely by ProcessPhi) is later encountered here and its operand has been
222 // changed to half.
223 uint32_t val_id = inst->GetSingleWordInOperand(0);
224 Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
225 if (inst->type_id() == val_inst->type_id()) inst->SetOpcode(SpvOpCopyObject);
226 return true; // modified
227 }
228
ProcessImageRef(Instruction * inst)229 bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
230 bool modified = false;
231 // If image reference, only need to convert dref args back to float32
232 if (dref_image_ops_.count(inst->opcode()) != 0) {
233 uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
234 if (converted_ids_.count(dref_id) > 0) {
235 GenConvert(&dref_id, 32, inst);
236 inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
237 get_def_use_mgr()->AnalyzeInstUse(inst);
238 modified = true;
239 }
240 }
241 return modified;
242 }
243
ProcessDefault(Instruction * inst)244 bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
245 bool modified = false;
246 // If non-relaxed instruction has changed operands, need to convert
247 // them back to float32
248 inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
249 if (converted_ids_.count(*idp) == 0) return;
250 uint32_t old_id = *idp;
251 GenConvert(idp, 32, inst);
252 if (*idp != old_id) modified = true;
253 });
254 if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
255 return modified;
256 }
257
GenHalfInst(Instruction * inst)258 bool ConvertToHalfPass::GenHalfInst(Instruction* inst) {
259 bool modified = false;
260 // Remember id for later deletion of RelaxedPrecision decoration
261 bool inst_relaxed = IsRelaxed(inst->result_id());
262 if (IsArithmetic(inst) && inst_relaxed)
263 modified = GenHalfArith(inst);
264 else if (inst->opcode() == SpvOpPhi && inst_relaxed)
265 modified = ProcessPhi(inst);
266 else if (inst->opcode() == SpvOpFConvert)
267 modified = ProcessConvert(inst);
268 else if (image_ops_.count(inst->opcode()) != 0)
269 modified = ProcessImageRef(inst);
270 else
271 modified = ProcessDefault(inst);
272 return modified;
273 }
274
CloseRelaxInst(Instruction * inst)275 bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) {
276 if (inst->result_id() == 0) return false;
277 if (IsRelaxed(inst->result_id())) return false;
278 if (!IsFloat(inst, 32)) return false;
279 if (IsDecoratedRelaxed(inst)) {
280 AddRelaxed(inst->result_id());
281 return true;
282 }
283 if (closure_ops_.count(inst->opcode()) == 0) return false;
284 // Can relax if all float operands are relaxed
285 bool relax = true;
286 inst->ForEachInId([&relax, this](uint32_t* idp) {
287 Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
288 if (!IsFloat(op_inst, 32)) return;
289 if (!IsRelaxed(*idp)) relax = false;
290 });
291 if (relax) {
292 AddRelaxed(inst->result_id());
293 return true;
294 }
295 // Can relax if all uses are relaxed
296 relax = true;
297 get_def_use_mgr()->ForEachUser(inst, [&relax, this](Instruction* uinst) {
298 if (uinst->result_id() == 0 || !IsFloat(uinst, 32) ||
299 (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id()))) {
300 relax = false;
301 return;
302 }
303 });
304 if (relax) {
305 AddRelaxed(inst->result_id());
306 return true;
307 }
308 return false;
309 }
310
ProcessFunction(Function * func)311 bool ConvertToHalfPass::ProcessFunction(Function* func) {
312 // Do a closure of Relaxed on composite and phi instructions
313 bool changed = true;
314 while (changed) {
315 changed = false;
316 cfg()->ForEachBlockInReversePostOrder(
317 func->entry().get(), [&changed, this](BasicBlock* bb) {
318 for (auto ii = bb->begin(); ii != bb->end(); ++ii)
319 changed |= CloseRelaxInst(&*ii);
320 });
321 }
322 // Do convert of relaxed instructions to half precision
323 bool modified = false;
324 cfg()->ForEachBlockInReversePostOrder(
325 func->entry().get(), [&modified, this](BasicBlock* bb) {
326 for (auto ii = bb->begin(); ii != bb->end(); ++ii)
327 modified |= GenHalfInst(&*ii);
328 });
329 // Replace invalid converts of matrix into equivalent vector extracts,
330 // converts and finally a composite construct
331 cfg()->ForEachBlockInReversePostOrder(
332 func->entry().get(), [&modified, this](BasicBlock* bb) {
333 for (auto ii = bb->begin(); ii != bb->end(); ++ii)
334 modified |= MatConvertCleanup(&*ii);
335 });
336 return modified;
337 }
338
ProcessImpl()339 Pass::Status ConvertToHalfPass::ProcessImpl() {
340 Pass::ProcessFunction pfn = [this](Function* fp) {
341 return ProcessFunction(fp);
342 };
343 bool modified = context()->ProcessEntryPointCallTree(pfn);
344 // If modified, make sure module has Float16 capability
345 if (modified) context()->AddCapability(SpvCapabilityFloat16);
346 // Remove all RelaxedPrecision decorations from instructions and globals
347 for (auto c_id : relaxed_ids_set_) {
348 modified |= RemoveRelaxedDecoration(c_id);
349 }
350 for (auto& val : get_module()->types_values()) {
351 uint32_t v_id = val.result_id();
352 if (v_id != 0) {
353 modified |= RemoveRelaxedDecoration(v_id);
354 }
355 }
356 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
357 }
358
Process()359 Pass::Status ConvertToHalfPass::Process() {
360 Initialize();
361 return ProcessImpl();
362 }
363
Initialize()364 void ConvertToHalfPass::Initialize() {
365 target_ops_core_ = {
366 SpvOpVectorExtractDynamic,
367 SpvOpVectorInsertDynamic,
368 SpvOpVectorShuffle,
369 SpvOpCompositeConstruct,
370 SpvOpCompositeInsert,
371 SpvOpCompositeExtract,
372 SpvOpCopyObject,
373 SpvOpTranspose,
374 SpvOpConvertSToF,
375 SpvOpConvertUToF,
376 // SpvOpFConvert,
377 // SpvOpQuantizeToF16,
378 SpvOpFNegate,
379 SpvOpFAdd,
380 SpvOpFSub,
381 SpvOpFMul,
382 SpvOpFDiv,
383 SpvOpFMod,
384 SpvOpVectorTimesScalar,
385 SpvOpMatrixTimesScalar,
386 SpvOpVectorTimesMatrix,
387 SpvOpMatrixTimesVector,
388 SpvOpMatrixTimesMatrix,
389 SpvOpOuterProduct,
390 SpvOpDot,
391 SpvOpSelect,
392 SpvOpFOrdEqual,
393 SpvOpFUnordEqual,
394 SpvOpFOrdNotEqual,
395 SpvOpFUnordNotEqual,
396 SpvOpFOrdLessThan,
397 SpvOpFUnordLessThan,
398 SpvOpFOrdGreaterThan,
399 SpvOpFUnordGreaterThan,
400 SpvOpFOrdLessThanEqual,
401 SpvOpFUnordLessThanEqual,
402 SpvOpFOrdGreaterThanEqual,
403 SpvOpFUnordGreaterThanEqual,
404 };
405 target_ops_450_ = {
406 GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, GLSLstd450FAbs,
407 GLSLstd450FSign, GLSLstd450Floor, GLSLstd450Ceil, GLSLstd450Fract,
408 GLSLstd450Radians, GLSLstd450Degrees, GLSLstd450Sin, GLSLstd450Cos,
409 GLSLstd450Tan, GLSLstd450Asin, GLSLstd450Acos, GLSLstd450Atan,
410 GLSLstd450Sinh, GLSLstd450Cosh, GLSLstd450Tanh, GLSLstd450Asinh,
411 GLSLstd450Acosh, GLSLstd450Atanh, GLSLstd450Atan2, GLSLstd450Pow,
412 GLSLstd450Exp, GLSLstd450Log, GLSLstd450Exp2, GLSLstd450Log2,
413 GLSLstd450Sqrt, GLSLstd450InverseSqrt, GLSLstd450Determinant,
414 GLSLstd450MatrixInverse,
415 // TODO(greg-lunarg): GLSLstd450ModfStruct,
416 GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp, GLSLstd450FMix,
417 GLSLstd450Step, GLSLstd450SmoothStep, GLSLstd450Fma,
418 // TODO(greg-lunarg): GLSLstd450FrexpStruct,
419 GLSLstd450Ldexp, GLSLstd450Length, GLSLstd450Distance, GLSLstd450Cross,
420 GLSLstd450Normalize, GLSLstd450FaceForward, GLSLstd450Reflect,
421 GLSLstd450Refract, GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp};
422 image_ops_ = {SpvOpImageSampleImplicitLod,
423 SpvOpImageSampleExplicitLod,
424 SpvOpImageSampleDrefImplicitLod,
425 SpvOpImageSampleDrefExplicitLod,
426 SpvOpImageSampleProjImplicitLod,
427 SpvOpImageSampleProjExplicitLod,
428 SpvOpImageSampleProjDrefImplicitLod,
429 SpvOpImageSampleProjDrefExplicitLod,
430 SpvOpImageFetch,
431 SpvOpImageGather,
432 SpvOpImageDrefGather,
433 SpvOpImageRead,
434 SpvOpImageSparseSampleImplicitLod,
435 SpvOpImageSparseSampleExplicitLod,
436 SpvOpImageSparseSampleDrefImplicitLod,
437 SpvOpImageSparseSampleDrefExplicitLod,
438 SpvOpImageSparseSampleProjImplicitLod,
439 SpvOpImageSparseSampleProjExplicitLod,
440 SpvOpImageSparseSampleProjDrefImplicitLod,
441 SpvOpImageSparseSampleProjDrefExplicitLod,
442 SpvOpImageSparseFetch,
443 SpvOpImageSparseGather,
444 SpvOpImageSparseDrefGather,
445 SpvOpImageSparseTexelsResident,
446 SpvOpImageSparseRead};
447 dref_image_ops_ = {
448 SpvOpImageSampleDrefImplicitLod,
449 SpvOpImageSampleDrefExplicitLod,
450 SpvOpImageSampleProjDrefImplicitLod,
451 SpvOpImageSampleProjDrefExplicitLod,
452 SpvOpImageDrefGather,
453 SpvOpImageSparseSampleDrefImplicitLod,
454 SpvOpImageSparseSampleDrefExplicitLod,
455 SpvOpImageSparseSampleProjDrefImplicitLod,
456 SpvOpImageSparseSampleProjDrefExplicitLod,
457 SpvOpImageSparseDrefGather,
458 };
459 closure_ops_ = {
460 SpvOpVectorExtractDynamic,
461 SpvOpVectorInsertDynamic,
462 SpvOpVectorShuffle,
463 SpvOpCompositeConstruct,
464 SpvOpCompositeInsert,
465 SpvOpCompositeExtract,
466 SpvOpCopyObject,
467 SpvOpTranspose,
468 SpvOpPhi,
469 };
470 relaxed_ids_set_.clear();
471 converted_ids_.clear();
472 }
473
474 } // namespace opt
475 } // namespace spvtools
476