• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/generate_webgpu_initializers_pass.h"
16 #include "source/opt/ir_context.h"
17 
18 namespace spvtools {
19 namespace opt {
20 
21 using inst_iterator = InstructionList::iterator;
22 
23 namespace {
24 
NeedsWebGPUInitializer(Instruction * inst)25 bool NeedsWebGPUInitializer(Instruction* inst) {
26   if (inst->opcode() != SpvOpVariable) return false;
27 
28   auto storage_class = inst->GetSingleWordOperand(2);
29   if (storage_class != SpvStorageClassOutput &&
30       storage_class != SpvStorageClassPrivate &&
31       storage_class != SpvStorageClassFunction) {
32     return false;
33   }
34 
35   if (inst->NumOperands() > 3) return false;
36 
37   return true;
38 }
39 
40 }  // namespace
41 
Process()42 Pass::Status GenerateWebGPUInitializersPass::Process() {
43   auto* module = context()->module();
44   bool changed = false;
45 
46   // Handle global/module scoped variables
47   for (auto iter = module->types_values_begin();
48        iter != module->types_values_end(); ++iter) {
49     Instruction* inst = &(*iter);
50 
51     if (inst->opcode() == SpvOpConstantNull) {
52       null_constant_type_map_[inst->type_id()] = inst;
53       seen_null_constants_.insert(inst);
54       continue;
55     }
56 
57     if (!NeedsWebGPUInitializer(inst)) continue;
58 
59     changed = true;
60 
61     auto* constant_inst = GetNullConstantForVariable(inst);
62     if (!constant_inst) return Status::Failure;
63 
64     if (seen_null_constants_.find(constant_inst) ==
65         seen_null_constants_.end()) {
66       constant_inst->InsertBefore(inst);
67       null_constant_type_map_[inst->type_id()] = inst;
68       seen_null_constants_.insert(inst);
69     }
70     AddNullInitializerToVariable(constant_inst, inst);
71   }
72 
73   // Handle local/function scoped variables
74   for (auto func = module->begin(); func != module->end(); ++func) {
75     auto block = func->entry().get();
76     for (auto iter = block->begin();
77          iter != block->end() && iter->opcode() == SpvOpVariable; ++iter) {
78       Instruction* inst = &(*iter);
79       if (!NeedsWebGPUInitializer(inst)) continue;
80 
81       changed = true;
82       auto* constant_inst = GetNullConstantForVariable(inst);
83       if (!constant_inst) return Status::Failure;
84 
85       AddNullInitializerToVariable(constant_inst, inst);
86     }
87   }
88 
89   return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
90 }
91 
GetNullConstantForVariable(Instruction * variable_inst)92 Instruction* GenerateWebGPUInitializersPass::GetNullConstantForVariable(
93     Instruction* variable_inst) {
94   auto constant_mgr = context()->get_constant_mgr();
95   auto* def_use_mgr = get_def_use_mgr();
96 
97   auto* ptr_inst = def_use_mgr->GetDef(variable_inst->type_id());
98   auto type_id = ptr_inst->GetInOperand(1).words[0];
99   if (null_constant_type_map_.find(type_id) == null_constant_type_map_.end()) {
100     auto* constant_type = context()->get_type_mgr()->GetType(type_id);
101     auto* constant = constant_mgr->GetConstant(constant_type, {});
102     return constant_mgr->GetDefiningInstruction(constant, type_id);
103   } else {
104     return null_constant_type_map_[type_id];
105   }
106 }
107 
AddNullInitializerToVariable(Instruction * constant_inst,Instruction * variable_inst)108 void GenerateWebGPUInitializersPass::AddNullInitializerToVariable(
109     Instruction* constant_inst, Instruction* variable_inst) {
110   auto constant_id = constant_inst->result_id();
111   variable_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {constant_id}));
112   get_def_use_mgr()->AnalyzeInstUse(variable_inst);
113 }
114 
115 }  // namespace opt
116 }  // namespace spvtools
117