/* Copyright (c) 2015-2019 The Khronos Group Inc.
 * Copyright (c) 2015-2019 Valve Corporation
 * Copyright (c) 2015-2019 LunarG, Inc.
 * Copyright (C) 2015-2019 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Author: Chris Forbes <chrisf@ijw.co.nz>
 */
#ifndef VULKAN_SHADER_VALIDATION_H
#define VULKAN_SHADER_VALIDATION_H

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <SPIRV/spirv.hpp>
#include <generated/spirv_tools_commit_id.h>
#include "spirv-tools/optimizer.hpp"

// A forward iterator over spirv instructions. Provides easy access to len, opcode, and content words
// without the caller needing to care too much about the physical SPIRV module layout.
31 struct spirv_inst_iter { 32 std::vector<uint32_t>::const_iterator zero; 33 std::vector<uint32_t>::const_iterator it; 34 lenspirv_inst_iter35 uint32_t len() const { 36 auto result = *it >> 16; 37 assert(result > 0); 38 return result; 39 } 40 opcodespirv_inst_iter41 uint32_t opcode() { return *it & 0x0ffffu; } 42 wordspirv_inst_iter43 uint32_t const &word(unsigned n) const { 44 assert(n < len()); 45 return it[n]; 46 } 47 offsetspirv_inst_iter48 uint32_t offset() { return (uint32_t)(it - zero); } 49 spirv_inst_iterspirv_inst_iter50 spirv_inst_iter() {} 51 spirv_inst_iterspirv_inst_iter52 spirv_inst_iter(std::vector<uint32_t>::const_iterator zero, std::vector<uint32_t>::const_iterator it) : zero(zero), it(it) {} 53 54 bool operator==(spirv_inst_iter const &other) const { return it == other.it; } 55 56 bool operator!=(spirv_inst_iter const &other) const { return it != other.it; } 57 58 spirv_inst_iter operator++(int) { // x++ 59 spirv_inst_iter ii = *this; 60 it += len(); 61 return ii; 62 } 63 64 spirv_inst_iter operator++() { // ++x; 65 it += len(); 66 return *this; 67 } 68 69 // The iterator and the value are the same thing. 
70 spirv_inst_iter &operator*() { return *this; } 71 spirv_inst_iter const &operator*() const { return *this; } 72 }; 73 74 struct decoration_set { 75 enum { 76 location_bit = 1 << 0, 77 patch_bit = 1 << 1, 78 relaxed_precision_bit = 1 << 2, 79 block_bit = 1 << 3, 80 buffer_block_bit = 1 << 4, 81 component_bit = 1 << 5, 82 input_attachment_index_bit = 1 << 6, 83 descriptor_set_bit = 1 << 7, 84 binding_bit = 1 << 8, 85 nonwritable_bit = 1 << 9, 86 builtin_bit = 1 << 10, 87 }; 88 uint32_t flags = 0; 89 uint32_t location = static_cast<uint32_t>(-1); 90 uint32_t component = 0; 91 uint32_t input_attachment_index = 0; 92 uint32_t descriptor_set = 0; 93 uint32_t binding = 0; 94 uint32_t builtin = static_cast<uint32_t>(-1); 95 mergedecoration_set96 void merge(decoration_set const &other) { 97 if (other.flags & location_bit) location = other.location; 98 if (other.flags & component_bit) component = other.component; 99 if (other.flags & input_attachment_index_bit) input_attachment_index = other.input_attachment_index; 100 if (other.flags & descriptor_set_bit) descriptor_set = other.descriptor_set; 101 if (other.flags & binding_bit) binding = other.binding; 102 if (other.flags & builtin_bit) builtin = other.builtin; 103 flags |= other.flags; 104 } 105 106 void add(uint32_t decoration, uint32_t value); 107 }; 108 109 struct SHADER_MODULE_STATE { 110 // The spirv image itself 111 std::vector<uint32_t> words; 112 // A mapping of <id> to the first word of its def. this is useful because walking type 113 // trees, constant expressions, etc requires jumping all over the instruction stream. 
114 std::unordered_map<unsigned, unsigned> def_index; 115 std::unordered_map<unsigned, decoration_set> decorations; 116 struct EntryPoint { 117 uint32_t offset; 118 VkShaderStageFlags stage; 119 }; 120 std::unordered_multimap<std::string, EntryPoint> entry_points; 121 bool has_valid_spirv; 122 VkShaderModule vk_shader_module; 123 uint32_t gpu_validation_shader_id; 124 PreprocessShaderBinarySHADER_MODULE_STATE125 std::vector<uint32_t> PreprocessShaderBinary(uint32_t *src_binary, size_t binary_size, spv_target_env env) { 126 std::vector<uint32_t> src(src_binary, src_binary + binary_size / sizeof(uint32_t)); 127 128 // Check if there are any group decoration instructions, and flatten them if found. 129 bool has_group_decoration = false; 130 bool done = false; 131 132 // Walk through the first part of the SPIR-V module, looking for group decoration instructions. 133 // Skip the header (5 words). 134 auto itr = spirv_inst_iter(src.begin(), src.begin() + 5); 135 auto itrend = spirv_inst_iter(src.begin(), src.end()); 136 while (itr != itrend && !done) { 137 spv::Op opcode = (spv::Op)itr.opcode(); 138 switch (opcode) { 139 case spv::OpDecorationGroup: 140 case spv::OpGroupDecorate: 141 case spv::OpGroupMemberDecorate: 142 has_group_decoration = true; 143 done = true; 144 break; 145 case spv::OpFunction: 146 // An OpFunction indicates there are no more decorations 147 done = true; 148 break; 149 default: 150 break; 151 } 152 itr++; 153 } 154 155 if (has_group_decoration) { 156 spvtools::Optimizer optimizer(env); 157 optimizer.RegisterPass(spvtools::CreateFlattenDecorationPass()); 158 std::vector<uint32_t> optimized_binary; 159 // Run optimizer to flatten decorations only, set skip_validation so as to not re-run validator 160 auto result = 161 optimizer.Run(src_binary, binary_size / sizeof(uint32_t), &optimized_binary, spvtools::ValidatorOptions(), true); 162 if (result) { 163 return optimized_binary; 164 } 165 } 166 // Return the original module. 
167 return src; 168 } 169 SHADER_MODULE_STATESHADER_MODULE_STATE170 SHADER_MODULE_STATE(VkShaderModuleCreateInfo const *pCreateInfo, VkShaderModule shaderModule, spv_target_env env, 171 uint32_t unique_shader_id) 172 : words(PreprocessShaderBinary((uint32_t *)pCreateInfo->pCode, pCreateInfo->codeSize, env)), 173 def_index(), 174 has_valid_spirv(true), 175 vk_shader_module(shaderModule), 176 gpu_validation_shader_id(unique_shader_id) { 177 BuildDefIndex(); 178 } 179 SHADER_MODULE_STATESHADER_MODULE_STATE180 SHADER_MODULE_STATE() : has_valid_spirv(false), vk_shader_module(VK_NULL_HANDLE) {} 181 get_decorationsSHADER_MODULE_STATE182 decoration_set get_decorations(unsigned id) const { 183 // return the actual decorations for this id, or a default set. 184 auto it = decorations.find(id); 185 if (it != decorations.end()) return it->second; 186 return decoration_set(); 187 } 188 189 // Expose begin() / end() to enable range-based for beginSHADER_MODULE_STATE190 spirv_inst_iter begin() const { return spirv_inst_iter(words.begin(), words.begin() + 5); } // First insn endSHADER_MODULE_STATE191 spirv_inst_iter end() const { return spirv_inst_iter(words.begin(), words.end()); } // Just past last insn 192 // Given an offset into the module, produce an iterator there. atSHADER_MODULE_STATE193 spirv_inst_iter at(unsigned offset) const { return spirv_inst_iter(words.begin(), words.begin() + offset); } 194 195 // Gets an iterator to the definition of an id get_defSHADER_MODULE_STATE196 spirv_inst_iter get_def(unsigned id) const { 197 auto it = def_index.find(id); 198 if (it == def_index.end()) { 199 return end(); 200 } 201 return at(it->second); 202 } 203 204 void BuildDefIndex(); 205 }; 206 207 class ValidationCache { 208 // hashes of shaders that have passed validation before, and can be skipped. 
209 // we don't store negative results, as we would have to also store what was 210 // wrong with them; also, we expect they will get fixed, so we're less 211 // likely to see them again. 212 std::unordered_set<uint32_t> good_shader_hashes; ValidationCache()213 ValidationCache() {} 214 215 public: Create(VkValidationCacheCreateInfoEXT const * pCreateInfo)216 static VkValidationCacheEXT Create(VkValidationCacheCreateInfoEXT const *pCreateInfo) { 217 auto cache = new ValidationCache(); 218 cache->Load(pCreateInfo); 219 return VkValidationCacheEXT(cache); 220 } 221 Load(VkValidationCacheCreateInfoEXT const * pCreateInfo)222 void Load(VkValidationCacheCreateInfoEXT const *pCreateInfo) { 223 const auto headerSize = 2 * sizeof(uint32_t) + VK_UUID_SIZE; 224 auto size = headerSize; 225 if (!pCreateInfo->pInitialData || pCreateInfo->initialDataSize < size) return; 226 227 uint32_t const *data = (uint32_t const *)pCreateInfo->pInitialData; 228 if (data[0] != size) return; 229 if (data[1] != VK_VALIDATION_CACHE_HEADER_VERSION_ONE_EXT) return; 230 uint8_t expected_uuid[VK_UUID_SIZE]; 231 Sha1ToVkUuid(SPIRV_TOOLS_COMMIT_ID, expected_uuid); 232 if (memcmp(&data[2], expected_uuid, VK_UUID_SIZE) != 0) return; // different version 233 234 data = (uint32_t const *)(reinterpret_cast<uint8_t const *>(data) + headerSize); 235 236 for (; size < pCreateInfo->initialDataSize; data++, size += sizeof(uint32_t)) { 237 good_shader_hashes.insert(*data); 238 } 239 } 240 Write(size_t * pDataSize,void * pData)241 void Write(size_t *pDataSize, void *pData) { 242 const auto headerSize = 2 * sizeof(uint32_t) + VK_UUID_SIZE; // 4 bytes for header size + 4 bytes for version number + UUID 243 if (!pData) { 244 *pDataSize = headerSize + good_shader_hashes.size() * sizeof(uint32_t); 245 return; 246 } 247 248 if (*pDataSize < headerSize) { 249 *pDataSize = 0; 250 return; // Too small for even the header! 
251 } 252 253 uint32_t *out = (uint32_t *)pData; 254 size_t actualSize = headerSize; 255 256 // Write the header 257 *out++ = headerSize; 258 *out++ = VK_VALIDATION_CACHE_HEADER_VERSION_ONE_EXT; 259 Sha1ToVkUuid(SPIRV_TOOLS_COMMIT_ID, reinterpret_cast<uint8_t *>(out)); 260 out = (uint32_t *)(reinterpret_cast<uint8_t *>(out) + VK_UUID_SIZE); 261 262 for (auto it = good_shader_hashes.begin(); it != good_shader_hashes.end() && actualSize < *pDataSize; 263 it++, out++, actualSize += sizeof(uint32_t)) { 264 *out = *it; 265 } 266 267 *pDataSize = actualSize; 268 } 269 Merge(ValidationCache const * other)270 void Merge(ValidationCache const *other) { 271 good_shader_hashes.reserve(good_shader_hashes.size() + other->good_shader_hashes.size()); 272 for (auto h : other->good_shader_hashes) good_shader_hashes.insert(h); 273 } 274 275 static uint32_t MakeShaderHash(VkShaderModuleCreateInfo const *smci); 276 Contains(uint32_t hash)277 bool Contains(uint32_t hash) { return good_shader_hashes.count(hash) != 0; } 278 Insert(uint32_t hash)279 void Insert(uint32_t hash) { good_shader_hashes.insert(hash); } 280 281 private: Sha1ToVkUuid(const char * sha1_str,uint8_t uuid[VK_UUID_SIZE])282 void Sha1ToVkUuid(const char *sha1_str, uint8_t uuid[VK_UUID_SIZE]) { 283 // Convert sha1_str from a hex string to binary. We only need VK_UUID_BYTES of 284 // output, so pad with zeroes if the input string is shorter than that, and truncate 285 // if it's longer. 286 char padded_sha1_str[2 * VK_UUID_SIZE + 1] = {}; 287 strncpy(padded_sha1_str, sha1_str, 2 * VK_UUID_SIZE + 1); 288 char byte_str[3] = {}; 289 for (uint32_t i = 0; i < VK_UUID_SIZE; ++i) { 290 byte_str[0] = padded_sha1_str[2 * i + 0]; 291 byte_str[1] = padded_sha1_str[2 * i + 1]; 292 uuid[i] = static_cast<uint8_t>(strtol(byte_str, NULL, 16)); 293 } 294 } 295 }; 296 297 #endif // VULKAN_SHADER_VALIDATION_H 298