• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "source/opt/set_spec_constant_default_value_pass.h"

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstring>
#include <tuple>
#include <vector>

#include "source/opt/def_use_manager.h"
#include "source/opt/types.h"
#include "source/util/make_unique.h"
#include "source/util/parse_number.h"
#include "spirv-tools/libspirv.h"
28 
29 namespace spvtools {
30 namespace opt {
31 namespace {
32 using utils::EncodeNumberStatus;
33 using utils::NumberType;
34 using utils::ParseAndEncodeNumber;
35 using utils::ParseNumber;
36 
37 // Given a numeric value in a null-terminated c string and the expected type of
38 // the value, parses the string and encodes it in a vector of words. If the
39 // value is a scalar integer or floating point value, encodes the value in
40 // SPIR-V encoding format. If the value is 'false' or 'true', returns a vector
41 // with single word with value 0 or 1 respectively. Returns the vector
42 // containing the encoded value on success. Otherwise returns an empty vector.
ParseDefaultValueStr(const char * text,const analysis::Type * type)43 std::vector<uint32_t> ParseDefaultValueStr(const char* text,
44                                            const analysis::Type* type) {
45   std::vector<uint32_t> result;
46   if (!strcmp(text, "true") && type->AsBool()) {
47     result.push_back(1u);
48   } else if (!strcmp(text, "false") && type->AsBool()) {
49     result.push_back(0u);
50   } else {
51     NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT};
52     if (const auto* IT = type->AsInteger()) {
53       number_type.bitwidth = IT->width();
54       number_type.kind =
55           IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
56     } else if (const auto* FT = type->AsFloat()) {
57       number_type.bitwidth = FT->width();
58       number_type.kind = SPV_NUMBER_FLOATING;
59     } else {
60       // Does not handle types other then boolean, integer or float. Returns
61       // empty vector.
62       result.clear();
63       return result;
64     }
65     EncodeNumberStatus rc = ParseAndEncodeNumber(
66         text, number_type, [&result](uint32_t word) { result.push_back(word); },
67         nullptr);
68     // Clear the result vector on failure.
69     if (rc != EncodeNumberStatus::kSuccess) {
70       result.clear();
71     }
72   }
73   return result;
74 }
75 
76 // Given a bit pattern and a type, checks if the bit pattern is compatible
77 // with the type. If so, returns the bit pattern, otherwise returns an empty
78 // bit pattern. If the given bit pattern is empty, returns an empty bit
79 // pattern. If the given type represents a SPIR-V Boolean type, the bit pattern
80 // to be returned is determined with the following standard:
81 //   If any words in the input bit pattern are non zero, returns a bit pattern
82 //   with 0x1, which represents a 'true'.
83 //   If all words in the bit pattern are zero, returns a bit pattern with 0x0,
84 //   which represents a 'false'.
85 // For integer and floating point types narrower than 32 bits, the upper bits
86 // in the input bit pattern are ignored.  Instead the upper bits are set
87 // according to SPIR-V literal requirements: sign extend a signed integer, and
88 // otherwise set the upper bits to zero.
ParseDefaultValueBitPattern(const std::vector<uint32_t> & input_bit_pattern,const analysis::Type * type)89 std::vector<uint32_t> ParseDefaultValueBitPattern(
90     const std::vector<uint32_t>& input_bit_pattern,
91     const analysis::Type* type) {
92   std::vector<uint32_t> result;
93   if (type->AsBool()) {
94     if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(),
95                     [](uint32_t i) { return i != 0; })) {
96       result.push_back(1u);
97     } else {
98       result.push_back(0u);
99     }
100     return result;
101   } else if (const auto* IT = type->AsInteger()) {
102     const auto width = IT->width();
103     assert(width > 0);
104     const auto adjusted_width = std::max(32u, width);
105     if (adjusted_width == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
106       result = std::vector<uint32_t>(input_bit_pattern);
107       if (width < 32) {
108         const uint32_t high_active_bit = (1u << width) >> 1;
109         if (IT->IsSigned() && (high_active_bit & result[0])) {
110           // Sign extend.  This overwrites the sign bit again, but that's ok.
111           result[0] = result[0] | ~(high_active_bit - 1);
112         } else {
113           // Upper bits must be zero.
114           result[0] = result[0] & ((1u << width) - 1);
115         }
116       }
117       return result;
118     }
119   } else if (const auto* FT = type->AsFloat()) {
120     const auto width = FT->width();
121     const auto adjusted_width = std::max(32u, width);
122     if (adjusted_width == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
123       result = std::vector<uint32_t>(input_bit_pattern);
124       if (width < 32) {
125         // Upper bits must be zero.
126         result[0] = result[0] & ((1u << width) - 1);
127       }
128       return result;
129     }
130   }
131   result.clear();
132   return result;
133 }
134 
135 // Returns true if the given instruction's result id could have a SpecId
136 // decoration.
CanHaveSpecIdDecoration(const Instruction & inst)137 bool CanHaveSpecIdDecoration(const Instruction& inst) {
138   switch (inst.opcode()) {
139     case spv::Op::OpSpecConstant:
140     case spv::Op::OpSpecConstantFalse:
141     case spv::Op::OpSpecConstantTrue:
142       return true;
143     default:
144       return false;
145   }
146 }
147 
148 // Given a decoration group defining instruction that is decorated with SpecId
149 // decoration, finds the spec constant defining instruction which is the real
150 // target of the SpecId decoration. Returns the spec constant defining
151 // instruction if such an instruction is found, otherwise returns a nullptr.
GetSpecIdTargetFromDecorationGroup(const Instruction & decoration_group_defining_inst,analysis::DefUseManager * def_use_mgr)152 Instruction* GetSpecIdTargetFromDecorationGroup(
153     const Instruction& decoration_group_defining_inst,
154     analysis::DefUseManager* def_use_mgr) {
155   // Find the OpGroupDecorate instruction which consumes the given decoration
156   // group. Note that the given decoration group has SpecId decoration, which
157   // is unique for different spec constants. So the decoration group cannot be
158   // consumed by different OpGroupDecorate instructions. Therefore we only need
159   // the first OpGroupDecoration instruction that uses the given decoration
160   // group.
161   Instruction* group_decorate_inst = nullptr;
162   if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst,
163                                  [&group_decorate_inst](Instruction* user) {
164                                    if (user->opcode() ==
165                                        spv::Op::OpGroupDecorate) {
166                                      group_decorate_inst = user;
167                                      return false;
168                                    }
169                                    return true;
170                                  }))
171     return nullptr;
172 
173   // Scan through the target ids of the OpGroupDecorate instruction. There
174   // should be only one spec constant target consumes the SpecId decoration.
175   // If multiple target ids are presented in the OpGroupDecorate instruction,
176   // they must be the same one that defined by an eligible spec constant
177   // instruction. If the OpGroupDecorate instruction has different target ids
178   // or a target id is not defined by an eligible spec cosntant instruction,
179   // returns a nullptr.
180   Instruction* target_inst = nullptr;
181   for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) {
182     // All the operands of a OpGroupDecorate instruction should be of type
183     // SPV_OPERAND_TYPE_ID.
184     uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i);
185     Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id);
186 
187     if (!candidate_inst) {
188       continue;
189     }
190 
191     if (!target_inst) {
192       // If the spec constant target has not been found yet, check if the
193       // candidate instruction is the target.
194       if (CanHaveSpecIdDecoration(*candidate_inst)) {
195         target_inst = candidate_inst;
196       } else {
197         // Spec id decoration should not be applied on other instructions.
198         // TODO(qining): Emit an error message in the invalid case once the
199         // error handling is done.
200         return nullptr;
201       }
202     } else {
203       // If the spec constant target has been found, check if the candidate
204       // instruction is the same one as the target. The module is invalid if
205       // the candidate instruction is different with the found target.
206       // TODO(qining): Emit an error messaage in the invalid case once the
207       // error handling is done.
208       if (candidate_inst != target_inst) return nullptr;
209     }
210   }
211   return target_inst;
212 }
213 }  // namespace
214 
Process()215 Pass::Status SetSpecConstantDefaultValuePass::Process() {
216   // The operand index of decoration target in an OpDecorate instruction.
217   constexpr uint32_t kTargetIdOperandIndex = 0;
218   // The operand index of the decoration literal in an OpDecorate instruction.
219   constexpr uint32_t kDecorationOperandIndex = 1;
220   // The operand index of Spec id literal value in an OpDecorate SpecId
221   // instruction.
222   constexpr uint32_t kSpecIdLiteralOperandIndex = 2;
223   // The number of operands in an OpDecorate SpecId instruction.
224   constexpr uint32_t kOpDecorateSpecIdNumOperands = 3;
225   // The in-operand index of the default value in a OpSpecConstant instruction.
226   constexpr uint32_t kOpSpecConstantLiteralInOperandIndex = 0;
227 
228   bool modified = false;
229   // Scan through all the annotation instructions to find 'OpDecorate SpecId'
230   // instructions. Then extract the decoration target of those instructions.
231   // The decoration targets should be spec constant defining instructions with
232   // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants
233   // will be used to look up their new default values in the mapping from
234   // spec id to new default value strings. Once a new default value string
235   // is found for a spec id, the string will be parsed according to the target
236   // spec constant type. The parsed value will be used to replace the original
237   // default value of the target spec constant.
238   for (Instruction& inst : context()->annotations()) {
239     // Only process 'OpDecorate SpecId' instructions
240     if (inst.opcode() != spv::Op::OpDecorate) continue;
241     if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue;
242     if (inst.GetSingleWordInOperand(kDecorationOperandIndex) !=
243         uint32_t(spv::Decoration::SpecId)) {
244       continue;
245     }
246 
247     // 'inst' is an OpDecorate SpecId instruction.
248     uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex);
249     uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex);
250 
251     // Find the spec constant defining instruction. Note that the
252     // target_id might be a decoration group id.
253     Instruction* spec_inst = nullptr;
254     if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) {
255       if (target_inst->opcode() == spv::Op::OpDecorationGroup) {
256         spec_inst =
257             GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr());
258       } else {
259         spec_inst = target_inst;
260       }
261     } else {
262       continue;
263     }
264     if (!spec_inst) continue;
265 
266     // Get the default value bit pattern for this spec id.
267     std::vector<uint32_t> bit_pattern;
268 
269     if (spec_id_to_value_str_.size() != 0) {
270       // Search for the new string-form default value for this spec id.
271       auto iter = spec_id_to_value_str_.find(spec_id);
272       if (iter == spec_id_to_value_str_.end()) {
273         continue;
274       }
275 
276       // Gets the string of the default value and parses it to bit pattern
277       // with the type of the spec constant.
278       const std::string& default_value_str = iter->second;
279       bit_pattern = ParseDefaultValueStr(
280           default_value_str.c_str(),
281           context()->get_type_mgr()->GetType(spec_inst->type_id()));
282 
283     } else {
284       // Search for the new bit-pattern-form default value for this spec id.
285       auto iter = spec_id_to_value_bit_pattern_.find(spec_id);
286       if (iter == spec_id_to_value_bit_pattern_.end()) {
287         continue;
288       }
289 
290       // Gets the bit-pattern of the default value from the map directly.
291       bit_pattern = ParseDefaultValueBitPattern(
292           iter->second,
293           context()->get_type_mgr()->GetType(spec_inst->type_id()));
294     }
295 
296     if (bit_pattern.empty()) continue;
297 
298     // Update the operand bit patterns of the spec constant defining
299     // instruction.
300     switch (spec_inst->opcode()) {
301       case spv::Op::OpSpecConstant:
302         // If the new value is the same with the original value, no
303         // need to do anything. Otherwise update the operand words.
304         if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex)
305                 .words != bit_pattern) {
306           spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex,
307                                   std::move(bit_pattern));
308           modified = true;
309         }
310         break;
311       case spv::Op::OpSpecConstantTrue:
312         // If the new value is also 'true', no need to change anything.
313         // Otherwise, set the opcode to OpSpecConstantFalse;
314         if (!static_cast<bool>(bit_pattern.front())) {
315           spec_inst->SetOpcode(spv::Op::OpSpecConstantFalse);
316           modified = true;
317         }
318         break;
319       case spv::Op::OpSpecConstantFalse:
320         // If the new value is also 'false', no need to change anything.
321         // Otherwise, set the opcode to OpSpecConstantTrue;
322         if (static_cast<bool>(bit_pattern.front())) {
323           spec_inst->SetOpcode(spv::Op::OpSpecConstantTrue);
324           modified = true;
325         }
326         break;
327       default:
328         break;
329     }
330     // No need to update the DefUse manager, as this pass does not change any
331     // ids.
332   }
333   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
334 }
335 
// Returns true if the given char is ':', '\0' or considered as blank space
// (i.e.: '\n', '\r', '\v', '\t', '\f' and ' ').
// The char is converted to unsigned char before calling std::isspace. Passing
// a negative value (any byte >= 0x80 where char is signed) is undefined
// behavior.
bool IsSeparator(char ch) {
  if (ch == ':' || ch == '\0') return true;
  return std::isspace(static_cast<unsigned char>(ch)) != 0;
}
341 
342 std::unique_ptr<SetSpecConstantDefaultValuePass::SpecIdToValueStrMap>
ParseDefaultValuesString(const char * str)343 SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) {
344   if (!str) return nullptr;
345 
346   auto spec_id_to_value = MakeUnique<SpecIdToValueStrMap>();
347 
348   // The parsing loop, break when points to the end.
349   while (*str) {
350     // Find the spec id.
351     while (std::isspace(*str)) str++;  // skip leading spaces.
352     const char* entry_begin = str;
353     while (!IsSeparator(*str)) str++;
354     const char* entry_end = str;
355     std::string spec_id_str(entry_begin, entry_end - entry_begin);
356     uint32_t spec_id = 0;
357     if (!ParseNumber(spec_id_str.c_str(), &spec_id)) {
358       // The spec id is not a valid uint32 number.
359       return nullptr;
360     }
361     auto iter = spec_id_to_value->find(spec_id);
362     if (iter != spec_id_to_value->end()) {
363       // Same spec id has been defined before
364       return nullptr;
365     }
366     // Find the ':', spaces between the spec id and the ':' are not allowed.
367     if (*str++ != ':') {
368       // ':' not found
369       return nullptr;
370     }
371     // Find the value string
372     const char* val_begin = str;
373     while (!IsSeparator(*str)) str++;
374     const char* val_end = str;
375     if (val_end == val_begin) {
376       // Value string is empty.
377       return nullptr;
378     }
379     // Update the mapping with spec id and value string.
380     (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin);
381 
382     // Skip trailing spaces.
383     while (std::isspace(*str)) str++;
384   }
385 
386   return spec_id_to_value;
387 }
388 
389 }  // namespace opt
390 }  // namespace spvtools
391