// Copyright (c) 2019 The Khronos Group Inc.
// Copyright (c) 2019 Valve Corporation
// Copyright (c) 2019 LunarG Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "convert_to_half_pass.h"

#include "source/opt/ir_builder.h"

namespace {

// Indices of operands in SPIR-V instructions
static const int kImageSampleDrefIdInIdx = 2;

}  // anonymous namespace

namespace spvtools {
namespace opt {

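// Returns true if |inst| is a core instruction or GLSL.std.450 extended
// instruction that this pass can convert to half precision.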
bool ConvertToHalfPass::IsArithmetic(Instruction* inst) {
  return target_ops_core_.count(inst->opcode()) != 0 ||
         (inst->opcode() == SpvOpExtInst &&
          inst->GetSingleWordInOperand(0) ==
              context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
          target_ops_450_.count(inst->GetSingleWordInOperand(1)) != 0);
}

bool ConvertToHalfPass::IsFloat(Instruction* inst, uint32_t width) {
  uint32_t ty_id = inst->type_id();
  if (ty_id == 0) return false;
  return Pass::IsFloat(ty_id, width);
}

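// Returns true if the result id of |inst| carries an explicit
// RelaxedPrecision decoration.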
bool ConvertToHalfPass::IsDecoratedRelaxed(Instruction* inst) {
  uint32_t r_id = inst->result_id();
  for (auto r_inst : get_decoration_mgr()->GetDecorationsFor(r_id, false))
    if (r_inst->opcode() == SpvOpDecorate &&
        r_inst->GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision)
      return true;
  return false;
}

bool ConvertToHalfPass::IsRelaxed(uint32_t id) {
  return relaxed_ids_set_.count(id) > 0;
}

void ConvertToHalfPass::AddRelaxed(uint32_t id) { relaxed_ids_set_.insert(id); }

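// The following helpers return registered float scalar, vector and matrix
// types of the requested bit |width| from the type manager.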
analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
  analysis::Float float_ty(width);
  return context()->get_type_mgr()->GetRegisteredType(&float_ty);
}

analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
                                                   uint32_t width) {
  analysis::Type* reg_float_ty = FloatScalarType(width);
  analysis::Vector vec_ty(reg_float_ty, v_len);
  return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
}

analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
                                                   uint32_t vty_id,
                                                   uint32_t width) {
  Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
  uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
  analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
  analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
  return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
}

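// Returns the id of a float type with the same shape as |ty_id| (scalar,
// vector or matrix) but with component width |width|, creating it if needed.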
uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
  analysis::Type* reg_equiv_ty;
  Instruction* ty_inst = get_def_use_mgr()->GetDef(ty_id);
  if (ty_inst->opcode() == SpvOpTypeMatrix)
    reg_equiv_ty = FloatMatrixType(ty_inst->GetSingleWordInOperand(1),
                                   ty_inst->GetSingleWordInOperand(0), width);
  else if (ty_inst->opcode() == SpvOpTypeVector)
    reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
  else  // SpvOpTypeFloat
    reg_equiv_ty = FloatScalarType(width);
  return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
}

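// If the value *val_idp is not already of float width |width|, inserts an
// OpFConvert (or a fresh OpUndef for undef values) before |inst| and updates
// *val_idp to the id of the converted value.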
void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
                                   Instruction* inst) {
  Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
  uint32_t ty_id = val_inst->type_id();
  uint32_t nty_id = EquivFloatTypeId(ty_id, width);
  if (nty_id == ty_id) return;
  Instruction* cvt_inst;
  InstructionBuilder builder(
      context(), inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  if (val_inst->opcode() == SpvOpUndef)
    cvt_inst = builder.AddNullaryOp(nty_id, SpvOpUndef);
  else
    cvt_inst = builder.AddUnaryOp(nty_id, SpvOpFConvert, *val_idp);
  *val_idp = cvt_inst->result_id();
}

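// Rewrites an invalid OpFConvert of a matrix as per-column extracts and
// converts, combined with an OpCompositeConstruct. Returns true if |inst|
// was such a convert and was rewritten.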
bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
  if (inst->opcode() != SpvOpFConvert) return false;
  uint32_t mty_id = inst->type_id();
  Instruction* mty_inst = get_def_use_mgr()->GetDef(mty_id);
  if (mty_inst->opcode() != SpvOpTypeMatrix) return false;
  uint32_t vty_id = mty_inst->GetSingleWordInOperand(0);
  uint32_t v_cnt = mty_inst->GetSingleWordInOperand(1);
  Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
  uint32_t cty_id = vty_inst->GetSingleWordInOperand(0);
  Instruction* cty_inst = get_def_use_mgr()->GetDef(cty_id);
  InstructionBuilder builder(
      context(), inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // Convert each component vector, combine them with OpCompositeConstruct
  // and replace original instruction.
  uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
  uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
  uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
  std::vector<Operand> opnds = {};
  for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
    Instruction* ext_inst = builder.AddIdLiteralOp(
        orig_vty_id, SpvOpCompositeExtract, orig_mat_id, vidx);
    Instruction* cvt_inst =
        builder.AddUnaryOp(vty_id, SpvOpFConvert, ext_inst->result_id());
    opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
  }
  uint32_t mat_id = TakeNextId();
  std::unique_ptr<Instruction> mat_inst(new Instruction(
      context(), SpvOpCompositeConstruct, mty_id, mat_id, opnds));
  (void)builder.AddInstruction(std::move(mat_inst));
  context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
  // Turn original instruction into copy so it is valid.
  inst->SetOpcode(SpvOpCopyObject);
  inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
  get_def_use_mgr()->AnalyzeInstUse(inst);
  return true;
}

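// Removes all RelaxedPrecision decorations from |id|; returns true if any
// decoration was removed.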
bool ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {
  return context()->get_decoration_mgr()->RemoveDecorationsFrom(
      id, [](const Instruction& dec) {
        if (dec.opcode() == SpvOpDecorate &&
            dec.GetSingleWordInOperand(1u) == SpvDecorationRelaxedPrecision)
          return true;
        else
          return false;
      });
}

bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
  bool modified = false;
  // Convert all float32 based operands to float16 equivalent and change
  // instruction type to float16 equivalent.
  inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
    Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
    if (!IsFloat(op_inst, 32)) return;
    GenConvert(idp, 16, inst);
    modified = true;
  });
  if (IsFloat(inst, 32)) {
    inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
    converted_ids_.insert(inst->result_id());
    modified = true;
  }
  if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  return modified;
}

bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
                                   uint32_t to_width) {
  // Add converts of any float operands to to_width if they are of from_width.
  // If converting to 16, change type of phi to float16 equivalent and remember
  // result id. Converts need to be added to preceding blocks.
  uint32_t ocnt = 0;
  uint32_t* prev_idp;
  bool modified = false;
  inst->ForEachInId([&ocnt, &prev_idp, &from_width, &to_width, &modified,
                     this](uint32_t* idp) {
    if (ocnt % 2 == 0) {
      prev_idp = idp;
    } else {
      Instruction* val_inst = get_def_use_mgr()->GetDef(*prev_idp);
      if (IsFloat(val_inst, from_width)) {
        BasicBlock* bp = context()->get_instr_block(*idp);
        auto insert_before = bp->tail();
        if (insert_before != bp->begin()) {
          --insert_before;
          if (insert_before->opcode() != SpvOpSelectionMerge &&
              insert_before->opcode() != SpvOpLoopMerge)
            ++insert_before;
        }
        GenConvert(prev_idp, to_width, &*insert_before);
        modified = true;
      }
    }
    ++ocnt;
  });
  if (to_width == 16u) {
    inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16u));
    converted_ids_.insert(inst->result_id());
    modified = true;
  }
  if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  return modified;
}

bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
  // If float32 and relaxed, change to float16 convert.
  if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
    inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
    get_def_use_mgr()->AnalyzeInstUse(inst);
    converted_ids_.insert(inst->result_id());
  }
  // If operand and result types are the same, change FConvert to CopyObject to
  // keep validator happy; simplification and DCE will clean it up.
  // One way this can happen is if an FConvert generated during this pass
  // (likely by ProcessPhi) is later encountered here and its operand has been
  // changed to half.
  uint32_t val_id = inst->GetSingleWordInOperand(0);
  Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
  if (inst->type_id() == val_inst->type_id()) inst->SetOpcode(SpvOpCopyObject);
  return true;  // modified
}

bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
  bool modified = false;
  // If image reference, only need to convert dref args back to float32.
  if (dref_image_ops_.count(inst->opcode()) != 0) {
    uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
    if (converted_ids_.count(dref_id) > 0) {
      GenConvert(&dref_id, 32, inst);
      inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
      get_def_use_mgr()->AnalyzeInstUse(inst);
      modified = true;
    }
  }
  return modified;
}

bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
  // If a non-relaxed instruction has converted operands, they need to be
  // converted back to float32.
  if (inst->opcode() == SpvOpPhi) return ProcessPhi(inst, 16u, 32u);
  bool modified = false;
  inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
    if (converted_ids_.count(*idp) == 0) return;
    uint32_t old_id = *idp;
    GenConvert(idp, 32, inst);
    if (*idp != old_id) modified = true;
  });
  if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  return modified;
}

bool ConvertToHalfPass::GenHalfInst(Instruction* inst) {
  bool modified = false;
  // Remember id for later deletion of RelaxedPrecision decoration.
  bool inst_relaxed = IsRelaxed(inst->result_id());
  if (IsArithmetic(inst) && inst_relaxed)
    modified = GenHalfArith(inst);
  else if (inst->opcode() == SpvOpPhi && inst_relaxed)
    modified = ProcessPhi(inst, 32u, 16u);
  else if (inst->opcode() == SpvOpFConvert)
    modified = ProcessConvert(inst);
  else if (image_ops_.count(inst->opcode()) != 0)
    modified = ProcessImageRef(inst);
  else
    modified = ProcessDefault(inst);
  return modified;
}

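// Adds |inst| to the relaxed set if it is a float32 instruction that is
// explicitly decorated RelaxedPrecision, or is a closure op whose float
// operands are all relaxed, or whose uses are all relaxed. Returns true if
// the relaxed set grew.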
bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) {
  if (inst->result_id() == 0) return false;
  if (IsRelaxed(inst->result_id())) return false;
  if (!IsFloat(inst, 32)) return false;
  if (IsDecoratedRelaxed(inst)) {
    AddRelaxed(inst->result_id());
    return true;
  }
  if (closure_ops_.count(inst->opcode()) == 0) return false;
  // Can relax if all float operands are relaxed.
  bool relax = true;
  inst->ForEachInId([&relax, this](uint32_t* idp) {
    Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
    if (!IsFloat(op_inst, 32)) return;
    if (!IsRelaxed(*idp)) relax = false;
  });
  if (relax) {
    AddRelaxed(inst->result_id());
    return true;
  }
  // Can relax if all uses are relaxed.
  relax = true;
  get_def_use_mgr()->ForEachUser(inst, [&relax, this](Instruction* uinst) {
    if (uinst->result_id() == 0 || !IsFloat(uinst, 32) ||
        (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id()))) {
      relax = false;
      return;
    }
  });
  if (relax) {
    AddRelaxed(inst->result_id());
    return true;
  }
  return false;
}

bool ConvertToHalfPass::ProcessFunction(Function* func) {
  // Do a closure of Relaxed on composite and phi instructions.
  bool changed = true;
  while (changed) {
    changed = false;
    cfg()->ForEachBlockInReversePostOrder(
        func->entry().get(), [&changed, this](BasicBlock* bb) {
          for (auto ii = bb->begin(); ii != bb->end(); ++ii)
            changed |= CloseRelaxInst(&*ii);
        });
  }
  // Convert relaxed instructions to half precision.
  bool modified = false;
  cfg()->ForEachBlockInReversePostOrder(
      func->entry().get(), [&modified, this](BasicBlock* bb) {
        for (auto ii = bb->begin(); ii != bb->end(); ++ii)
          modified |= GenHalfInst(&*ii);
      });
  // Replace invalid converts of matrices with equivalent vector extracts,
  // converts and finally a composite construct.
  cfg()->ForEachBlockInReversePostOrder(
      func->entry().get(), [&modified, this](BasicBlock* bb) {
        for (auto ii = bb->begin(); ii != bb->end(); ++ii)
          modified |= MatConvertCleanup(&*ii);
      });
  return modified;
}

Pass::Status ConvertToHalfPass::ProcessImpl() {
  Pass::ProcessFunction pfn = [this](Function* fp) {
    return ProcessFunction(fp);
  };
  bool modified = context()->ProcessReachableCallTree(pfn);
  // If modified, make sure module has Float16 capability.
  if (modified) context()->AddCapability(SpvCapabilityFloat16);
  // Remove all RelaxedPrecision decorations from instructions and globals.
  for (auto c_id : relaxed_ids_set_) {
    modified |= RemoveRelaxedDecoration(c_id);
  }
  for (auto& val : get_module()->types_values()) {
    uint32_t v_id = val.result_id();
    if (v_id != 0) {
      modified |= RemoveRelaxedDecoration(v_id);
    }
  }
  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}

Pass::Status ConvertToHalfPass::Process() {
  Initialize();
  return ProcessImpl();
}

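// Populates the opcode sets that drive the pass: core and GLSL.std.450 ops
// eligible for half precision, image ops, dref image ops whose dref operand
// is converted back to float32, and the ops over which the Relaxed closure
// is propagated.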
void ConvertToHalfPass::Initialize() {
  target_ops_core_ = {
      SpvOpVectorExtractDynamic,
      SpvOpVectorInsertDynamic,
      SpvOpVectorShuffle,
      SpvOpCompositeConstruct,
      SpvOpCompositeInsert,
      SpvOpCompositeExtract,
      SpvOpCopyObject,
      SpvOpTranspose,
      SpvOpConvertSToF,
      SpvOpConvertUToF,
      // SpvOpFConvert,
      // SpvOpQuantizeToF16,
      SpvOpFNegate,
      SpvOpFAdd,
      SpvOpFSub,
      SpvOpFMul,
      SpvOpFDiv,
      SpvOpFMod,
      SpvOpVectorTimesScalar,
      SpvOpMatrixTimesScalar,
      SpvOpVectorTimesMatrix,
      SpvOpMatrixTimesVector,
      SpvOpMatrixTimesMatrix,
      SpvOpOuterProduct,
      SpvOpDot,
      SpvOpSelect,
      SpvOpFOrdEqual,
      SpvOpFUnordEqual,
      SpvOpFOrdNotEqual,
      SpvOpFUnordNotEqual,
      SpvOpFOrdLessThan,
      SpvOpFUnordLessThan,
      SpvOpFOrdGreaterThan,
      SpvOpFUnordGreaterThan,
      SpvOpFOrdLessThanEqual,
      SpvOpFUnordLessThanEqual,
      SpvOpFOrdGreaterThanEqual,
      SpvOpFUnordGreaterThanEqual,
  };
  target_ops_450_ = {
      GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, GLSLstd450FAbs,
      GLSLstd450FSign, GLSLstd450Floor, GLSLstd450Ceil, GLSLstd450Fract,
      GLSLstd450Radians, GLSLstd450Degrees, GLSLstd450Sin, GLSLstd450Cos,
      GLSLstd450Tan, GLSLstd450Asin, GLSLstd450Acos, GLSLstd450Atan,
      GLSLstd450Sinh, GLSLstd450Cosh, GLSLstd450Tanh, GLSLstd450Asinh,
      GLSLstd450Acosh, GLSLstd450Atanh, GLSLstd450Atan2, GLSLstd450Pow,
      GLSLstd450Exp, GLSLstd450Log, GLSLstd450Exp2, GLSLstd450Log2,
      GLSLstd450Sqrt, GLSLstd450InverseSqrt, GLSLstd450Determinant,
      GLSLstd450MatrixInverse,
      // TODO(greg-lunarg): GLSLstd450ModfStruct,
      GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp, GLSLstd450FMix,
      GLSLstd450Step, GLSLstd450SmoothStep, GLSLstd450Fma,
      // TODO(greg-lunarg): GLSLstd450FrexpStruct,
      GLSLstd450Ldexp, GLSLstd450Length, GLSLstd450Distance, GLSLstd450Cross,
      GLSLstd450Normalize, GLSLstd450FaceForward, GLSLstd450Reflect,
      GLSLstd450Refract, GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp};
  image_ops_ = {SpvOpImageSampleImplicitLod,
                SpvOpImageSampleExplicitLod,
                SpvOpImageSampleDrefImplicitLod,
                SpvOpImageSampleDrefExplicitLod,
                SpvOpImageSampleProjImplicitLod,
                SpvOpImageSampleProjExplicitLod,
                SpvOpImageSampleProjDrefImplicitLod,
                SpvOpImageSampleProjDrefExplicitLod,
                SpvOpImageFetch,
                SpvOpImageGather,
                SpvOpImageDrefGather,
                SpvOpImageRead,
                SpvOpImageSparseSampleImplicitLod,
                SpvOpImageSparseSampleExplicitLod,
                SpvOpImageSparseSampleDrefImplicitLod,
                SpvOpImageSparseSampleDrefExplicitLod,
                SpvOpImageSparseSampleProjImplicitLod,
                SpvOpImageSparseSampleProjExplicitLod,
                SpvOpImageSparseSampleProjDrefImplicitLod,
                SpvOpImageSparseSampleProjDrefExplicitLod,
                SpvOpImageSparseFetch,
                SpvOpImageSparseGather,
                SpvOpImageSparseDrefGather,
                SpvOpImageSparseTexelsResident,
                SpvOpImageSparseRead};
  dref_image_ops_ = {
      SpvOpImageSampleDrefImplicitLod,
      SpvOpImageSampleDrefExplicitLod,
      SpvOpImageSampleProjDrefImplicitLod,
      SpvOpImageSampleProjDrefExplicitLod,
      SpvOpImageDrefGather,
      SpvOpImageSparseSampleDrefImplicitLod,
      SpvOpImageSparseSampleDrefExplicitLod,
      SpvOpImageSparseSampleProjDrefImplicitLod,
      SpvOpImageSparseSampleProjDrefExplicitLod,
      SpvOpImageSparseDrefGather,
  };
  closure_ops_ = {
      SpvOpVectorExtractDynamic,
      SpvOpVectorInsertDynamic,
      SpvOpVectorShuffle,
      SpvOpCompositeConstruct,
      SpvOpCompositeInsert,
      SpvOpCompositeExtract,
      SpvOpCopyObject,
      SpvOpTranspose,
      SpvOpPhi,
  };
  relaxed_ids_set_.clear();
  converted_ids_.clear();
}

}  // namespace opt
}  // namespace spvtools