1 // Copyright (c) 2017 Pierre Moreau
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "spirv-tools/linker.hpp"
16
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
28
29 #include "source/assembly_grammar.h"
30 #include "source/diagnostic.h"
31 #include "source/opt/build_module.h"
32 #include "source/opt/compact_ids_pass.h"
33 #include "source/opt/decoration_manager.h"
34 #include "source/opt/ir_builder.h"
35 #include "source/opt/ir_loader.h"
36 #include "source/opt/pass_manager.h"
37 #include "source/opt/remove_duplicates_pass.h"
38 #include "source/opt/remove_unused_interface_variables_pass.h"
39 #include "source/opt/type_manager.h"
40 #include "source/spirv_constant.h"
41 #include "source/spirv_target_env.h"
42 #include "source/util/make_unique.h"
43 #include "source/util/string_utils.h"
44 #include "spirv-tools/libspirv.hpp"
45
46 namespace spvtools {
47 namespace {
48
49 using opt::Instruction;
50 using opt::InstructionBuilder;
51 using opt::IRContext;
52 using opt::Module;
53 using opt::PassManager;
54 using opt::RemoveDuplicatesPass;
55 using opt::analysis::DecorationManager;
56 using opt::analysis::DefUseManager;
57 using opt::analysis::Function;
58 using opt::analysis::Type;
59 using opt::analysis::TypeManager;
60
61 // Stores various information about an imported or exported symbol.
62 struct LinkageSymbolInfo {
63 spv::Id id; // ID of the symbol
64 spv::Id type_id; // ID of the type of the symbol
65 std::string name; // unique name defining the symbol and used for matching
66 // imports and exports together
67 std::vector<spv::Id> parameter_ids; // ID of the parameters of the symbol, if
68 // it is a function
69 };
70 struct LinkageEntry {
71 LinkageSymbolInfo imported_symbol;
72 LinkageSymbolInfo exported_symbol;
73
LinkageEntryspvtools::__anon38ae8e2d0111::LinkageEntry74 LinkageEntry(const LinkageSymbolInfo& import_info,
75 const LinkageSymbolInfo& export_info)
76 : imported_symbol(import_info), exported_symbol(export_info) {}
77 };
78 using LinkageTable = std::vector<LinkageEntry>;
79
80 // Shifts the IDs used in each binary of |modules| so that they occupy a
81 // disjoint range from the other binaries, and compute the new ID bound which
82 // is returned in |max_id_bound|.
83 //
84 // Both |modules| and |max_id_bound| should not be null, and |modules| should
85 // not be empty either. Furthermore |modules| should not contain any null
86 // pointers.
87 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
88 std::vector<opt::Module*>* modules,
89 uint32_t* max_id_bound);
90
91 // Generates the header for the linked module and returns it in |header|.
92 //
93 // |header| should not be null, |modules| should not be empty and pointers
94 // should be non-null. |max_id_bound| should be strictly greater than 0.
95 spv_result_t GenerateHeader(const MessageConsumer& consumer,
96 const std::vector<opt::Module*>& modules,
97 uint32_t max_id_bound, opt::ModuleHeader* header,
98 const LinkerOptions& options);
99
100 // Merge all the modules from |in_modules| into a single module owned by
101 // |linked_context|.
102 //
103 // |linked_context| should not be null.
104 spv_result_t MergeModules(const MessageConsumer& consumer,
105 const std::vector<Module*>& in_modules,
106 const AssemblyGrammar& grammar,
107 IRContext* linked_context);
108
109 // Compute all pairs of import and export and return it in |linkings_to_do|.
110 //
111 // |linkings_to_do should not be null. Built-in symbols will be ignored.
112 //
113 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
114 // currently not handled. (You could have a group being
115 // applied to a single ID.)
116 // TODO(pierremoreau): What should be the proper behaviour with built-in
117 // symbols?
118 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
119 const opt::IRContext& linked_context,
120 const DefUseManager& def_use_manager,
121 const DecorationManager& decoration_manager,
122 bool allow_partial_linkage,
123 LinkageTable* linkings_to_do);
124
125 // Checks that for each pair of import and export, the import and export have
126 // the same type as well as the same decorations.
127 //
128 // TODO(pierremoreau): Decorations on functions parameters are currently not
129 // checked.
130 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
131 const LinkageTable& linkings_to_do,
132 bool allow_ptr_type_mismatch,
133 opt::IRContext* context);
134
135 // Remove linkage specific instructions, such as prototypes of imported
136 // functions, declarations of imported variables, import (and export if
137 // necessary) linkage attributes.
138 //
139 // |linked_context| and |decoration_manager| should not be null, and the
140 // 'RemoveDuplicatePass' should be run first.
141 //
142 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
143 // currently not handled. (You could have a group being
144 // applied to a single ID.)
145 spv_result_t RemoveLinkageSpecificInstructions(
146 const MessageConsumer& consumer, const LinkerOptions& options,
147 const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
148 opt::IRContext* linked_context);
149
150 // Verify that the unique ids of each instruction in |linked_context| (i.e. the
151 // merged module) are truly unique. Does not check the validity of other ids
152 spv_result_t VerifyIds(const MessageConsumer& consumer,
153 opt::IRContext* linked_context);
154
155 // Verify that the universal limits are not crossed, and warn the user
156 // otherwise.
157 //
158 // TODO(pierremoreau):
159 // - Verify against the limits of the environment (e.g. Vulkan limits if
160 // consuming vulkan1.x)
161 spv_result_t VerifyLimits(const MessageConsumer& consumer,
162 const opt::IRContext& linked_context);
163
ShiftIdsInModules(const MessageConsumer & consumer,std::vector<opt::Module * > * modules,uint32_t * max_id_bound)164 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
165 std::vector<opt::Module*>* modules,
166 uint32_t* max_id_bound) {
167 spv_position_t position = {};
168
169 if (modules == nullptr)
170 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
171 << "|modules| of ShiftIdsInModules should not be null.";
172 if (modules->empty())
173 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
174 << "|modules| of ShiftIdsInModules should not be empty.";
175 if (max_id_bound == nullptr)
176 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
177 << "|max_id_bound| of ShiftIdsInModules should not be null.";
178
179 const size_t id_bound =
180 std::accumulate(modules->begin(), modules->end(), static_cast<size_t>(1),
181 [](const size_t& accumulation, opt::Module* module) {
182 return accumulation + module->IdBound() - 1u;
183 });
184 if (id_bound > std::numeric_limits<uint32_t>::max())
185 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
186 << "Too many IDs (" << id_bound
187 << "): combining all modules would overflow the 32-bit word of the "
188 "SPIR-V header.";
189
190 *max_id_bound = static_cast<uint32_t>(id_bound);
191
192 uint32_t id_offset = modules->front()->IdBound() - 1u;
193 for (auto module_iter = modules->begin() + 1; module_iter != modules->end();
194 ++module_iter) {
195 Module* module = *module_iter;
196 module->ForEachInst([&id_offset](Instruction* insn) {
197 insn->ForEachId([&id_offset](uint32_t* id) { *id += id_offset; });
198 });
199 id_offset += module->IdBound() - 1u;
200
201 // Invalidate the DefUseManager
202 module->context()->InvalidateAnalyses(opt::IRContext::kAnalysisDefUse);
203 }
204
205 return SPV_SUCCESS;
206 }
207
GenerateHeader(const MessageConsumer & consumer,const std::vector<opt::Module * > & modules,uint32_t max_id_bound,opt::ModuleHeader * header,const LinkerOptions & options)208 spv_result_t GenerateHeader(const MessageConsumer& consumer,
209 const std::vector<opt::Module*>& modules,
210 uint32_t max_id_bound, opt::ModuleHeader* header,
211 const LinkerOptions& options) {
212 spv_position_t position = {};
213
214 if (modules.empty())
215 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
216 << "|modules| of GenerateHeader should not be empty.";
217 if (max_id_bound == 0u)
218 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
219 << "|max_id_bound| of GenerateHeader should not be null.";
220
221 uint32_t linked_version = modules.front()->version();
222 for (std::size_t i = 1; i < modules.size(); ++i) {
223 const uint32_t module_version = modules[i]->version();
224 if (options.GetUseHighestVersion()) {
225 linked_version = std::max(linked_version, module_version);
226 } else if (module_version != linked_version) {
227 return DiagnosticStream({0, 0, 1}, consumer, "", SPV_ERROR_INTERNAL)
228 << "Conflicting SPIR-V versions: "
229 << SPV_SPIRV_VERSION_MAJOR_PART(linked_version) << "."
230 << SPV_SPIRV_VERSION_MINOR_PART(linked_version)
231 << " (input modules 1 through " << i << ") vs "
232 << SPV_SPIRV_VERSION_MAJOR_PART(module_version) << "."
233 << SPV_SPIRV_VERSION_MINOR_PART(module_version)
234 << " (input module " << (i + 1) << ").";
235 }
236 }
237
238 header->magic_number = spv::MagicNumber;
239 header->version = linked_version;
240 header->generator = SPV_GENERATOR_WORD(SPV_GENERATOR_KHRONOS_LINKER, 0);
241 header->bound = max_id_bound;
242 header->schema = 0u;
243
244 return SPV_SUCCESS;
245 }
246
MergeModules(const MessageConsumer & consumer,const std::vector<Module * > & input_modules,const AssemblyGrammar & grammar,IRContext * linked_context)247 spv_result_t MergeModules(const MessageConsumer& consumer,
248 const std::vector<Module*>& input_modules,
249 const AssemblyGrammar& grammar,
250 IRContext* linked_context) {
251 spv_position_t position = {};
252
253 if (linked_context == nullptr)
254 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
255 << "|linked_module| of MergeModules should not be null.";
256 Module* linked_module = linked_context->module();
257
258 if (input_modules.empty()) return SPV_SUCCESS;
259
260 for (const auto& module : input_modules)
261 for (const auto& inst : module->capabilities())
262 linked_module->AddCapability(
263 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
264
265 for (const auto& module : input_modules)
266 for (const auto& inst : module->extensions())
267 linked_module->AddExtension(
268 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
269
270 for (const auto& module : input_modules)
271 for (const auto& inst : module->ext_inst_imports())
272 linked_module->AddExtInstImport(
273 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
274
275 const Instruction* linked_memory_model_inst =
276 input_modules.front()->GetMemoryModel();
277 if (linked_memory_model_inst == nullptr) {
278 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
279 << "Input module 1 is lacking an OpMemoryModel instruction.";
280 }
281 const uint32_t linked_addressing_model =
282 linked_memory_model_inst->GetSingleWordOperand(0u);
283 const uint32_t linked_memory_model =
284 linked_memory_model_inst->GetSingleWordOperand(1u);
285
286 for (std::size_t i = 1; i < input_modules.size(); ++i) {
287 const Module* module = input_modules[i];
288 const Instruction* memory_model_inst = module->GetMemoryModel();
289 if (memory_model_inst == nullptr)
290 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
291 << "Input module " << (i + 1)
292 << " is lacking an OpMemoryModel instruction.";
293
294 const uint32_t module_addressing_model =
295 memory_model_inst->GetSingleWordOperand(0u);
296 if (module_addressing_model != linked_addressing_model) {
297 spv_operand_desc linked_desc = nullptr, module_desc = nullptr;
298 grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
299 linked_addressing_model, &linked_desc);
300 grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
301 module_addressing_model, &module_desc);
302 return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
303 << "Conflicting addressing models: " << linked_desc->name
304 << " (input modules 1 through " << i << ") vs "
305 << module_desc->name << " (input module " << (i + 1) << ").";
306 }
307
308 const uint32_t module_memory_model =
309 memory_model_inst->GetSingleWordOperand(1u);
310 if (module_memory_model != linked_memory_model) {
311 spv_operand_desc linked_desc = nullptr, module_desc = nullptr;
312 grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, linked_memory_model,
313 &linked_desc);
314 grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, module_memory_model,
315 &module_desc);
316 return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
317 << "Conflicting memory models: " << linked_desc->name
318 << " (input modules 1 through " << i << ") vs "
319 << module_desc->name << " (input module " << (i + 1) << ").";
320 }
321 }
322 linked_module->SetMemoryModel(std::unique_ptr<Instruction>(
323 linked_memory_model_inst->Clone(linked_context)));
324
325 std::vector<std::pair<uint32_t, std::string>> entry_points;
326 for (const auto& module : input_modules)
327 for (const auto& inst : module->entry_points()) {
328 const uint32_t model = inst.GetSingleWordInOperand(0);
329 const std::string name = inst.GetInOperand(2).AsString();
330 const auto i = std::find_if(
331 entry_points.begin(), entry_points.end(),
332 [model, name](const std::pair<uint32_t, std::string>& v) {
333 return v.first == model && v.second == name;
334 });
335 if (i != entry_points.end()) {
336 spv_operand_desc desc = nullptr;
337 grammar.lookupOperand(SPV_OPERAND_TYPE_EXECUTION_MODEL, model, &desc);
338 return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
339 << "The entry point \"" << name << "\", with execution model "
340 << desc->name << ", was already defined.";
341 }
342 linked_module->AddEntryPoint(
343 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
344 entry_points.emplace_back(model, name);
345 }
346
347 for (const auto& module : input_modules)
348 for (const auto& inst : module->execution_modes())
349 linked_module->AddExecutionMode(
350 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
351
352 for (const auto& module : input_modules)
353 for (const auto& inst : module->debugs1())
354 linked_module->AddDebug1Inst(
355 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
356
357 for (const auto& module : input_modules)
358 for (const auto& inst : module->debugs2())
359 linked_module->AddDebug2Inst(
360 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
361
362 for (const auto& module : input_modules)
363 for (const auto& inst : module->debugs3())
364 linked_module->AddDebug3Inst(
365 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
366
367 for (const auto& module : input_modules)
368 for (const auto& inst : module->ext_inst_debuginfo())
369 linked_module->AddExtInstDebugInfo(
370 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
371
372 // If the generated module uses SPIR-V 1.1 or higher, add an
373 // OpModuleProcessed instruction about the linking step.
374 if (linked_module->version() >= SPV_SPIRV_VERSION_WORD(1, 1)) {
375 const std::string processed_string("Linked by SPIR-V Tools Linker");
376 std::vector<uint32_t> processed_words =
377 spvtools::utils::MakeVector(processed_string);
378 linked_module->AddDebug3Inst(std::unique_ptr<Instruction>(
379 new Instruction(linked_context, spv::Op::OpModuleProcessed, 0u, 0u,
380 {{SPV_OPERAND_TYPE_LITERAL_STRING, processed_words}})));
381 }
382
383 for (const auto& module : input_modules)
384 for (const auto& inst : module->annotations())
385 linked_module->AddAnnotationInst(
386 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
387
388 // TODO(pierremoreau): Since the modules have not been validate, should we
389 // expect spv::StorageClass::Function variables outside
390 // functions?
391 for (const auto& module : input_modules) {
392 for (const auto& inst : module->types_values()) {
393 linked_module->AddType(
394 std::unique_ptr<Instruction>(inst.Clone(linked_context)));
395 }
396 }
397
398 // Process functions and their basic blocks
399 for (const auto& module : input_modules) {
400 for (const auto& func : *module) {
401 std::unique_ptr<opt::Function> cloned_func(func.Clone(linked_context));
402 linked_module->AddFunction(std::move(cloned_func));
403 }
404 }
405
406 return SPV_SUCCESS;
407 }
408
GetImportExportPairs(const MessageConsumer & consumer,const opt::IRContext & linked_context,const DefUseManager & def_use_manager,const DecorationManager & decoration_manager,bool allow_partial_linkage,LinkageTable * linkings_to_do)409 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
410 const opt::IRContext& linked_context,
411 const DefUseManager& def_use_manager,
412 const DecorationManager& decoration_manager,
413 bool allow_partial_linkage,
414 LinkageTable* linkings_to_do) {
415 spv_position_t position = {};
416
417 if (linkings_to_do == nullptr)
418 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
419 << "|linkings_to_do| of GetImportExportPairs should not be empty.";
420
421 std::vector<LinkageSymbolInfo> imports;
422 std::unordered_map<std::string, std::vector<LinkageSymbolInfo>> exports;
423
424 // Figure out the imports and exports
425 for (const auto& decoration : linked_context.annotations()) {
426 if (decoration.opcode() != spv::Op::OpDecorate ||
427 spv::Decoration(decoration.GetSingleWordInOperand(1u)) !=
428 spv::Decoration::LinkageAttributes)
429 continue;
430
431 const spv::Id id = decoration.GetSingleWordInOperand(0u);
432 // Ignore if the targeted symbol is a built-in
433 bool is_built_in = false;
434 for (const auto& id_decoration :
435 decoration_manager.GetDecorationsFor(id, false)) {
436 if (spv::Decoration(id_decoration->GetSingleWordInOperand(1u)) ==
437 spv::Decoration::BuiltIn) {
438 is_built_in = true;
439 break;
440 }
441 }
442 if (is_built_in) {
443 continue;
444 }
445
446 const uint32_t type = decoration.GetSingleWordInOperand(3u);
447
448 LinkageSymbolInfo symbol_info;
449 symbol_info.name = decoration.GetInOperand(2u).AsString();
450 symbol_info.id = id;
451 symbol_info.type_id = 0u;
452
453 // Retrieve the type of the current symbol. This information will be used
454 // when checking that the imported and exported symbols have the same
455 // types.
456 const Instruction* def_inst = def_use_manager.GetDef(id);
457 if (def_inst == nullptr)
458 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
459 << "ID " << id << " is never defined:\n";
460
461 if (def_inst->opcode() == spv::Op::OpVariable) {
462 symbol_info.type_id = def_inst->type_id();
463 } else if (def_inst->opcode() == spv::Op::OpFunction) {
464 symbol_info.type_id = def_inst->GetSingleWordInOperand(1u);
465
466 // range-based for loop calls begin()/end(), but never cbegin()/cend(),
467 // which will not work here.
468 for (auto func_iter = linked_context.module()->cbegin();
469 func_iter != linked_context.module()->cend(); ++func_iter) {
470 if (func_iter->result_id() != id) continue;
471 func_iter->ForEachParam([&symbol_info](const Instruction* inst) {
472 symbol_info.parameter_ids.push_back(inst->result_id());
473 });
474 }
475 } else {
476 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
477 << "Only global variables and functions can be decorated using"
478 << " LinkageAttributes; " << id << " is neither of them.\n";
479 }
480
481 if (spv::LinkageType(type) == spv::LinkageType::Import)
482 imports.push_back(symbol_info);
483 else if (spv::LinkageType(type) == spv::LinkageType::Export)
484 exports[symbol_info.name].push_back(symbol_info);
485 }
486
487 // Find the import/export pairs
488 for (const auto& import : imports) {
489 std::vector<LinkageSymbolInfo> possible_exports;
490 const auto& exp = exports.find(import.name);
491 if (exp != exports.end()) possible_exports = exp->second;
492 if (possible_exports.empty() && !allow_partial_linkage)
493 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
494 << "Unresolved external reference to \"" << import.name << "\".";
495 else if (possible_exports.size() > 1u)
496 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
497 << "Too many external references, " << possible_exports.size()
498 << ", were found for \"" << import.name << "\".";
499
500 if (!possible_exports.empty())
501 linkings_to_do->emplace_back(import, possible_exports.front());
502 }
503
504 return SPV_SUCCESS;
505 }
506
CheckImportExportCompatibility(const MessageConsumer & consumer,const LinkageTable & linkings_to_do,bool allow_ptr_type_mismatch,opt::IRContext * context)507 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
508 const LinkageTable& linkings_to_do,
509 bool allow_ptr_type_mismatch,
510 opt::IRContext* context) {
511 spv_position_t position = {};
512
513 // Ensure the import and export types are the same.
514 const DecorationManager& decoration_manager = *context->get_decoration_mgr();
515 const TypeManager& type_manager = *context->get_type_mgr();
516 for (const auto& linking_entry : linkings_to_do) {
517 Type* imported_symbol_type =
518 type_manager.GetType(linking_entry.imported_symbol.type_id);
519 Type* exported_symbol_type =
520 type_manager.GetType(linking_entry.exported_symbol.type_id);
521 if (!(*imported_symbol_type == *exported_symbol_type)) {
522 Function* imported_symbol_type_func = imported_symbol_type->AsFunction();
523 Function* exported_symbol_type_func = exported_symbol_type->AsFunction();
524
525 if (imported_symbol_type_func && exported_symbol_type_func) {
526 const auto& imported_params = imported_symbol_type_func->param_types();
527 const auto& exported_params = exported_symbol_type_func->param_types();
528 // allow_ptr_type_mismatch allows linking functions where the pointer
529 // type of arguments doesn't match. Everything else still needs to be
530 // equal. This is to workaround LLVM-17+ not having typed pointers and
531 // generated SPIR-Vs not knowing the actual pointer types in some cases.
532 if (allow_ptr_type_mismatch &&
533 imported_params.size() == exported_params.size()) {
534 bool correct = true;
535 for (size_t i = 0; i < imported_params.size(); i++) {
536 const auto& imported_param = imported_params[i];
537 const auto& exported_param = exported_params[i];
538
539 if (!imported_param->IsSame(exported_param) &&
540 (imported_param->kind() != Type::kPointer ||
541 exported_param->kind() != Type::kPointer)) {
542 correct = false;
543 break;
544 }
545 }
546 if (correct) continue;
547 }
548 }
549 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
550 << "Type mismatch on symbol \""
551 << linking_entry.imported_symbol.name
552 << "\" between imported variable/function %"
553 << linking_entry.imported_symbol.id
554 << " and exported variable/function %"
555 << linking_entry.exported_symbol.id << ".";
556 }
557 }
558
559 // Ensure the import and export decorations are similar
560 for (const auto& linking_entry : linkings_to_do) {
561 if (!decoration_manager.HaveTheSameDecorations(
562 linking_entry.imported_symbol.id, linking_entry.exported_symbol.id))
563 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
564 << "Decorations mismatch on symbol \""
565 << linking_entry.imported_symbol.name
566 << "\" between imported variable/function %"
567 << linking_entry.imported_symbol.id
568 << " and exported variable/function %"
569 << linking_entry.exported_symbol.id << ".";
570 // TODO(pierremoreau): Decorations on function parameters should probably
571 // match, except for FuncParamAttr if I understand the
572 // spec correctly.
573 // TODO(pierremoreau): Decorations on the function return type should
574 // match, except for FuncParamAttr.
575 }
576
577 return SPV_SUCCESS;
578 }
579
RemoveLinkageSpecificInstructions(const MessageConsumer & consumer,const LinkerOptions & options,const LinkageTable & linkings_to_do,DecorationManager * decoration_manager,opt::IRContext * linked_context)580 spv_result_t RemoveLinkageSpecificInstructions(
581 const MessageConsumer& consumer, const LinkerOptions& options,
582 const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
583 opt::IRContext* linked_context) {
584 spv_position_t position = {};
585
586 if (decoration_manager == nullptr)
587 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
588 << "|decoration_manager| of RemoveLinkageSpecificInstructions "
589 "should not be empty.";
590 if (linked_context == nullptr)
591 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
592 << "|linked_module| of RemoveLinkageSpecificInstructions should not "
593 "be empty.";
594
595 // TODO(pierremoreau): Remove FuncParamAttr decorations of imported
596 // functions' return type.
597
598 // Remove prototypes of imported functions
599 for (const auto& linking_entry : linkings_to_do) {
600 for (auto func_iter = linked_context->module()->begin();
601 func_iter != linked_context->module()->end();) {
602 if (func_iter->result_id() == linking_entry.imported_symbol.id)
603 func_iter = func_iter.Erase();
604 else
605 ++func_iter;
606 }
607 }
608
609 // Remove declarations of imported variables
610 for (const auto& linking_entry : linkings_to_do) {
611 auto next = linked_context->types_values_begin();
612 for (auto inst = next; inst != linked_context->types_values_end();
613 inst = next) {
614 ++next;
615 if (inst->result_id() == linking_entry.imported_symbol.id) {
616 linked_context->KillInst(&*inst);
617 }
618 }
619 }
620
621 // If partial linkage is allowed, we need an efficient way to check whether
622 // an imported ID had a corresponding export symbol. As uses of the imported
623 // symbol have already been replaced by the exported symbol, use the exported
624 // symbol ID.
625 // TODO(pierremoreau): This will not work if the decoration is applied
626 // through a group, but the linker does not support that
627 // either.
628 std::unordered_set<spv::Id> imports;
629 if (options.GetAllowPartialLinkage()) {
630 imports.reserve(linkings_to_do.size());
631 for (const auto& linking_entry : linkings_to_do)
632 imports.emplace(linking_entry.exported_symbol.id);
633 }
634
635 // Remove import linkage attributes
636 auto next = linked_context->annotation_begin();
637 for (auto inst = next; inst != linked_context->annotation_end();
638 inst = next) {
639 ++next;
640 // If this is an import annotation:
641 // * if we do not allow partial linkage, remove all import annotations;
642 // * otherwise, remove the annotation only if there was a corresponding
643 // export.
644 if (inst->opcode() == spv::Op::OpDecorate &&
645 spv::Decoration(inst->GetSingleWordOperand(1u)) ==
646 spv::Decoration::LinkageAttributes &&
647 spv::LinkageType(inst->GetSingleWordOperand(3u)) ==
648 spv::LinkageType::Import &&
649 (!options.GetAllowPartialLinkage() ||
650 imports.find(inst->GetSingleWordOperand(0u)) != imports.end())) {
651 linked_context->KillInst(&*inst);
652 }
653 }
654
655 // Remove export linkage attributes if making an executable
656 if (!options.GetCreateLibrary()) {
657 next = linked_context->annotation_begin();
658 for (auto inst = next; inst != linked_context->annotation_end();
659 inst = next) {
660 ++next;
661 if (inst->opcode() == spv::Op::OpDecorate &&
662 spv::Decoration(inst->GetSingleWordOperand(1u)) ==
663 spv::Decoration::LinkageAttributes &&
664 spv::LinkageType(inst->GetSingleWordOperand(3u)) ==
665 spv::LinkageType::Export) {
666 linked_context->KillInst(&*inst);
667 }
668 }
669 }
670
671 // Remove Linkage capability if making an executable and partial linkage is
672 // not allowed
673 if (!options.GetCreateLibrary() && !options.GetAllowPartialLinkage()) {
674 for (auto& inst : linked_context->capabilities())
675 if (spv::Capability(inst.GetSingleWordInOperand(0u)) ==
676 spv::Capability::Linkage) {
677 linked_context->KillInst(&inst);
678 // The RemoveDuplicatesPass did remove duplicated capabilities, so we
679 // now there aren’t more spv::Capability::Linkage further down.
680 break;
681 }
682 }
683
684 return SPV_SUCCESS;
685 }
686
VerifyIds(const MessageConsumer & consumer,opt::IRContext * linked_context)687 spv_result_t VerifyIds(const MessageConsumer& consumer,
688 opt::IRContext* linked_context) {
689 std::unordered_set<uint32_t> ids;
690 bool ok = true;
691 linked_context->module()->ForEachInst(
692 [&ids, &ok](const opt::Instruction* inst) {
693 ok &= ids.insert(inst->unique_id()).second;
694 });
695
696 if (!ok) {
697 consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module");
698 return SPV_ERROR_INVALID_ID;
699 }
700
701 return SPV_SUCCESS;
702 }
703
VerifyLimits(const MessageConsumer & consumer,const opt::IRContext & linked_context)704 spv_result_t VerifyLimits(const MessageConsumer& consumer,
705 const opt::IRContext& linked_context) {
706 spv_position_t position = {};
707
708 const uint32_t max_id_bound = linked_context.module()->id_bound();
709 if (max_id_bound >= SPV_LIMIT_RESULT_ID_BOUND)
710 DiagnosticStream({0u, 0u, 4u}, consumer, "", SPV_WARNING)
711 << "The minimum limit of IDs, " << (SPV_LIMIT_RESULT_ID_BOUND - 1)
712 << ", was exceeded:"
713 << " " << max_id_bound << " is the current ID bound.\n"
714 << "The resulting module might not be supported by all "
715 "implementations.";
716
717 size_t num_global_values = 0u;
718 for (const auto& inst : linked_context.module()->types_values()) {
719 num_global_values += inst.opcode() == spv::Op::OpVariable;
720 }
721 if (num_global_values >= SPV_LIMIT_GLOBAL_VARIABLES_MAX)
722 DiagnosticStream(position, consumer, "", SPV_WARNING)
723 << "The minimum limit of global values, "
724 << (SPV_LIMIT_GLOBAL_VARIABLES_MAX - 1) << ", was exceeded;"
725 << " " << num_global_values << " global values were found.\n"
726 << "The resulting module might not be supported by all "
727 "implementations.";
728
729 return SPV_SUCCESS;
730 }
731
FixFunctionCallTypes(opt::IRContext & context,const LinkageTable & linkings)732 spv_result_t FixFunctionCallTypes(opt::IRContext& context,
733 const LinkageTable& linkings) {
734 auto mod = context.module();
735 const auto type_manager = context.get_type_mgr();
736 const auto def_use_mgr = context.get_def_use_mgr();
737
738 for (auto& func : *mod) {
739 func.ForEachInst([&](Instruction* inst) {
740 if (inst->opcode() != spv::Op::OpFunctionCall) return;
741 opt::Operand& target = inst->GetInOperand(0);
742
743 // only fix calls to imported functions
744 auto linking = std::find_if(
745 linkings.begin(), linkings.end(), [&](const auto& entry) {
746 return entry.exported_symbol.id == target.AsId();
747 });
748 if (linking == linkings.end()) return;
749
750 auto builder = InstructionBuilder(&context, inst);
751 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
752 auto exported_func_param =
753 def_use_mgr->GetDef(linking->exported_symbol.parameter_ids[i - 1]);
754 const Type* target_type =
755 type_manager->GetType(exported_func_param->type_id());
756 if (target_type->kind() != Type::kPointer) continue;
757
758 opt::Operand& arg = inst->GetInOperand(i);
759 const Type* param_type =
760 type_manager->GetType(def_use_mgr->GetDef(arg.AsId())->type_id());
761
762 // No need to cast if it already matches
763 if (*param_type == *target_type) continue;
764
765 auto new_id = context.TakeNextId();
766
767 // cast to the expected pointer type
768 builder.AddInstruction(MakeUnique<opt::Instruction>(
769 &context, spv::Op::OpBitcast, exported_func_param->type_id(),
770 new_id,
771 opt::Instruction::OperandList(
772 {{SPV_OPERAND_TYPE_ID, {arg.AsId()}}})));
773
774 inst->SetInOperand(i, {new_id});
775 }
776 });
777 }
778 context.InvalidateAnalyses(opt::IRContext::kAnalysisDefUse |
779 opt::IRContext::kAnalysisInstrToBlockMapping);
780 return SPV_SUCCESS;
781 }
782
783 } // namespace
784
Link(const Context & context,const std::vector<std::vector<uint32_t>> & binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)785 spv_result_t Link(const Context& context,
786 const std::vector<std::vector<uint32_t>>& binaries,
787 std::vector<uint32_t>* linked_binary,
788 const LinkerOptions& options) {
789 std::vector<const uint32_t*> binary_ptrs;
790 binary_ptrs.reserve(binaries.size());
791 std::vector<size_t> binary_sizes;
792 binary_sizes.reserve(binaries.size());
793
794 for (const auto& binary : binaries) {
795 binary_ptrs.push_back(binary.data());
796 binary_sizes.push_back(binary.size());
797 }
798
799 return Link(context, binary_ptrs.data(), binary_sizes.data(), binaries.size(),
800 linked_binary, options);
801 }
802
Link(const Context & context,const uint32_t * const * binaries,const size_t * binary_sizes,size_t num_binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)803 spv_result_t Link(const Context& context, const uint32_t* const* binaries,
804 const size_t* binary_sizes, size_t num_binaries,
805 std::vector<uint32_t>* linked_binary,
806 const LinkerOptions& options) {
807 spv_position_t position = {};
808 const spv_context& c_context = context.CContext();
809 const MessageConsumer& consumer = c_context->consumer;
810
811 linked_binary->clear();
812 if (num_binaries == 0u)
813 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
814 << "No modules were given.";
815
816 std::vector<std::unique_ptr<IRContext>> ir_contexts;
817 std::vector<Module*> modules;
818 modules.reserve(num_binaries);
819 for (size_t i = 0u; i < num_binaries; ++i) {
820 const uint32_t schema = binaries[i][4u];
821 if (schema != 0u) {
822 position.index = 4u;
823 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
824 << "Schema is non-zero for module " << i + 1 << ".";
825 }
826
827 std::unique_ptr<IRContext> ir_context = BuildModule(
828 c_context->target_env, consumer, binaries[i], binary_sizes[i]);
829 if (ir_context == nullptr)
830 return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
831 << "Failed to build module " << i + 1 << " out of " << num_binaries
832 << ".";
833 modules.push_back(ir_context->module());
834 ir_contexts.push_back(std::move(ir_context));
835 }
836
837 // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint
838 // range from the other binaries, and compute the new ID bound.
839 uint32_t max_id_bound = 0u;
840 spv_result_t res = ShiftIdsInModules(consumer, &modules, &max_id_bound);
841 if (res != SPV_SUCCESS) return res;
842
843 // Phase 2: Generate the header
844 opt::ModuleHeader header;
845 res = GenerateHeader(consumer, modules, max_id_bound, &header, options);
846 if (res != SPV_SUCCESS) return res;
847 IRContext linked_context(c_context->target_env, consumer);
848 linked_context.module()->SetHeader(header);
849
850 // Phase 3: Merge all the binaries into a single one.
851 AssemblyGrammar grammar(c_context);
852 res = MergeModules(consumer, modules, grammar, &linked_context);
853 if (res != SPV_SUCCESS) return res;
854
855 if (options.GetVerifyIds()) {
856 res = VerifyIds(consumer, &linked_context);
857 if (res != SPV_SUCCESS) return res;
858 }
859
860 // Phase 4: Remove duplicates
861 PassManager manager;
862 manager.SetMessageConsumer(consumer);
863 manager.AddPass<RemoveDuplicatesPass>();
864 opt::Pass::Status pass_res = manager.Run(&linked_context);
865 if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
866
867 // Phase 5: Find the import/export pairs
868 LinkageTable linkings_to_do;
869 res = GetImportExportPairs(consumer, linked_context,
870 *linked_context.get_def_use_mgr(),
871 *linked_context.get_decoration_mgr(),
872 options.GetAllowPartialLinkage(), &linkings_to_do);
873 if (res != SPV_SUCCESS) return res;
874
875 // Phase 6: Ensure the import and export have the same types and decorations.
876 res = CheckImportExportCompatibility(consumer, linkings_to_do,
877 options.GetAllowPtrTypeMismatch(),
878 &linked_context);
879 if (res != SPV_SUCCESS) return res;
880
881 // Phase 7: Remove all names and decorations of import variables/functions
882 for (const auto& linking_entry : linkings_to_do) {
883 linked_context.KillNamesAndDecorates(linking_entry.imported_symbol.id);
884 for (const auto parameter_id :
885 linking_entry.imported_symbol.parameter_ids) {
886 linked_context.KillNamesAndDecorates(parameter_id);
887 }
888 }
889
890 // Phase 8: Rematch import variables/functions to export variables/functions
891 for (const auto& linking_entry : linkings_to_do) {
892 linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id,
893 linking_entry.exported_symbol.id);
894 }
895
896 // Phase 9: Remove linkage specific instructions, such as import/export
897 // attributes, linkage capability, etc. if applicable
898 res = RemoveLinkageSpecificInstructions(consumer, options, linkings_to_do,
899 linked_context.get_decoration_mgr(),
900 &linked_context);
901 if (res != SPV_SUCCESS) return res;
902
903 // Phase 10: Optionally fix function call types
904 if (options.GetAllowPtrTypeMismatch()) {
905 res = FixFunctionCallTypes(linked_context, linkings_to_do);
906 if (res != SPV_SUCCESS) return res;
907 }
908
909 // Phase 11: Compact the IDs used in the module
910 manager.AddPass<opt::CompactIdsPass>();
911 pass_res = manager.Run(&linked_context);
912 if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
913
914 // Phase 12: Recompute EntryPoint variables
915 manager.AddPass<opt::RemoveUnusedInterfaceVariablesPass>();
916 pass_res = manager.Run(&linked_context);
917 if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
918
919 // Phase 13: Warn if SPIR-V limits were exceeded
920 res = VerifyLimits(consumer, linked_context);
921 if (res != SPV_SUCCESS) return res;
922
923 // Phase 14: Output the module
924 linked_context.module()->ToBinary(linked_binary, true);
925
926 return SPV_SUCCESS;
927 }
928
929 } // namespace spvtools
930