• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Pierre Moreau
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "spirv-tools/linker.hpp"
16 
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
28 
29 #include "source/assembly_grammar.h"
30 #include "source/diagnostic.h"
31 #include "source/opt/build_module.h"
32 #include "source/opt/compact_ids_pass.h"
33 #include "source/opt/decoration_manager.h"
34 #include "source/opt/ir_loader.h"
35 #include "source/opt/pass_manager.h"
36 #include "source/opt/remove_duplicates_pass.h"
37 #include "source/opt/remove_unused_interface_variables_pass.h"
38 #include "source/opt/type_manager.h"
39 #include "source/spirv_constant.h"
40 #include "source/spirv_target_env.h"
41 #include "source/util/make_unique.h"
42 #include "source/util/string_utils.h"
43 #include "spirv-tools/libspirv.hpp"
44 
45 namespace spvtools {
46 namespace {
47 
48 using opt::Instruction;
49 using opt::IRContext;
50 using opt::Module;
51 using opt::PassManager;
52 using opt::RemoveDuplicatesPass;
53 using opt::analysis::DecorationManager;
54 using opt::analysis::DefUseManager;
55 using opt::analysis::Type;
56 using opt::analysis::TypeManager;
57 
58 // Stores various information about an imported or exported symbol.
59 struct LinkageSymbolInfo {
60   spv::Id id;        // ID of the symbol
61   spv::Id type_id;   // ID of the type of the symbol
62   std::string name;  // unique name defining the symbol and used for matching
63                      // imports and exports together
64   std::vector<spv::Id> parameter_ids;  // ID of the parameters of the symbol, if
65                                        // it is a function
66 };
67 struct LinkageEntry {
68   LinkageSymbolInfo imported_symbol;
69   LinkageSymbolInfo exported_symbol;
70 
LinkageEntryspvtools::__anona33a67a70111::LinkageEntry71   LinkageEntry(const LinkageSymbolInfo& import_info,
72                const LinkageSymbolInfo& export_info)
73       : imported_symbol(import_info), exported_symbol(export_info) {}
74 };
75 using LinkageTable = std::vector<LinkageEntry>;
76 
77 // Shifts the IDs used in each binary of |modules| so that they occupy a
78 // disjoint range from the other binaries, and compute the new ID bound which
79 // is returned in |max_id_bound|.
80 //
81 // Both |modules| and |max_id_bound| should not be null, and |modules| should
82 // not be empty either. Furthermore |modules| should not contain any null
83 // pointers.
84 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
85                                std::vector<opt::Module*>* modules,
86                                uint32_t* max_id_bound);
87 
88 // Generates the header for the linked module and returns it in |header|.
89 //
90 // |header| should not be null, |modules| should not be empty and pointers
91 // should be non-null. |max_id_bound| should be strictly greater than 0.
92 spv_result_t GenerateHeader(const MessageConsumer& consumer,
93                             const std::vector<opt::Module*>& modules,
94                             uint32_t max_id_bound, opt::ModuleHeader* header,
95                             const LinkerOptions& options);
96 
97 // Merge all the modules from |in_modules| into a single module owned by
98 // |linked_context|.
99 //
100 // |linked_context| should not be null.
101 spv_result_t MergeModules(const MessageConsumer& consumer,
102                           const std::vector<Module*>& in_modules,
103                           const AssemblyGrammar& grammar,
104                           IRContext* linked_context);
105 
106 // Compute all pairs of import and export and return it in |linkings_to_do|.
107 //
108 // |linkings_to_do should not be null. Built-in symbols will be ignored.
109 //
110 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
111 //                     currently not handled. (You could have a group being
112 //                     applied to a single ID.)
113 // TODO(pierremoreau): What should be the proper behaviour with built-in
114 //                     symbols?
115 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
116                                   const opt::IRContext& linked_context,
117                                   const DefUseManager& def_use_manager,
118                                   const DecorationManager& decoration_manager,
119                                   bool allow_partial_linkage,
120                                   LinkageTable* linkings_to_do);
121 
122 // Checks that for each pair of import and export, the import and export have
123 // the same type as well as the same decorations.
124 //
125 // TODO(pierremoreau): Decorations on functions parameters are currently not
126 // checked.
127 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
128                                             const LinkageTable& linkings_to_do,
129                                             opt::IRContext* context);
130 
131 // Remove linkage specific instructions, such as prototypes of imported
132 // functions, declarations of imported variables, import (and export if
133 // necessary) linkage attributes.
134 //
135 // |linked_context| and |decoration_manager| should not be null, and the
136 // 'RemoveDuplicatePass' should be run first.
137 //
138 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
139 //                     currently not handled. (You could have a group being
140 //                     applied to a single ID.)
141 spv_result_t RemoveLinkageSpecificInstructions(
142     const MessageConsumer& consumer, const LinkerOptions& options,
143     const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
144     opt::IRContext* linked_context);
145 
146 // Verify that the unique ids of each instruction in |linked_context| (i.e. the
147 // merged module) are truly unique. Does not check the validity of other ids
148 spv_result_t VerifyIds(const MessageConsumer& consumer,
149                        opt::IRContext* linked_context);
150 
151 // Verify that the universal limits are not crossed, and warn the user
152 // otherwise.
153 //
154 // TODO(pierremoreau):
155 // - Verify against the limits of the environment (e.g. Vulkan limits if
156 //   consuming vulkan1.x)
157 spv_result_t VerifyLimits(const MessageConsumer& consumer,
158                           const opt::IRContext& linked_context);
159 
ShiftIdsInModules(const MessageConsumer & consumer,std::vector<opt::Module * > * modules,uint32_t * max_id_bound)160 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
161                                std::vector<opt::Module*>* modules,
162                                uint32_t* max_id_bound) {
163   spv_position_t position = {};
164 
165   if (modules == nullptr)
166     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
167            << "|modules| of ShiftIdsInModules should not be null.";
168   if (modules->empty())
169     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
170            << "|modules| of ShiftIdsInModules should not be empty.";
171   if (max_id_bound == nullptr)
172     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
173            << "|max_id_bound| of ShiftIdsInModules should not be null.";
174 
175   const size_t id_bound =
176       std::accumulate(modules->begin(), modules->end(), static_cast<size_t>(1),
177                       [](const size_t& accumulation, opt::Module* module) {
178                         return accumulation + module->IdBound() - 1u;
179                       });
180   if (id_bound > std::numeric_limits<uint32_t>::max())
181     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
182            << "Too many IDs (" << id_bound
183            << "): combining all modules would overflow the 32-bit word of the "
184               "SPIR-V header.";
185 
186   *max_id_bound = static_cast<uint32_t>(id_bound);
187 
188   uint32_t id_offset = modules->front()->IdBound() - 1u;
189   for (auto module_iter = modules->begin() + 1; module_iter != modules->end();
190        ++module_iter) {
191     Module* module = *module_iter;
192     module->ForEachInst([&id_offset](Instruction* insn) {
193       insn->ForEachId([&id_offset](uint32_t* id) { *id += id_offset; });
194     });
195     id_offset += module->IdBound() - 1u;
196 
197     // Invalidate the DefUseManager
198     module->context()->InvalidateAnalyses(opt::IRContext::kAnalysisDefUse);
199   }
200 
201   return SPV_SUCCESS;
202 }
203 
GenerateHeader(const MessageConsumer & consumer,const std::vector<opt::Module * > & modules,uint32_t max_id_bound,opt::ModuleHeader * header,const LinkerOptions & options)204 spv_result_t GenerateHeader(const MessageConsumer& consumer,
205                             const std::vector<opt::Module*>& modules,
206                             uint32_t max_id_bound, opt::ModuleHeader* header,
207                             const LinkerOptions& options) {
208   spv_position_t position = {};
209 
210   if (modules.empty())
211     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
212            << "|modules| of GenerateHeader should not be empty.";
213   if (max_id_bound == 0u)
214     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
215            << "|max_id_bound| of GenerateHeader should not be null.";
216 
217   uint32_t linked_version = modules.front()->version();
218   for (std::size_t i = 1; i < modules.size(); ++i) {
219     const uint32_t module_version = modules[i]->version();
220     if (options.GetUseHighestVersion()) {
221       linked_version = std::max(linked_version, module_version);
222     } else if (module_version != linked_version) {
223       return DiagnosticStream({0, 0, 1}, consumer, "", SPV_ERROR_INTERNAL)
224              << "Conflicting SPIR-V versions: "
225              << SPV_SPIRV_VERSION_MAJOR_PART(linked_version) << "."
226              << SPV_SPIRV_VERSION_MINOR_PART(linked_version)
227              << " (input modules 1 through " << i << ") vs "
228              << SPV_SPIRV_VERSION_MAJOR_PART(module_version) << "."
229              << SPV_SPIRV_VERSION_MINOR_PART(module_version)
230              << " (input module " << (i + 1) << ").";
231     }
232   }
233 
234   header->magic_number = spv::MagicNumber;
235   header->version = linked_version;
236   header->generator = SPV_GENERATOR_WORD(SPV_GENERATOR_KHRONOS_LINKER, 0);
237   header->bound = max_id_bound;
238   header->schema = 0u;
239 
240   return SPV_SUCCESS;
241 }
242 
MergeModules(const MessageConsumer & consumer,const std::vector<Module * > & input_modules,const AssemblyGrammar & grammar,IRContext * linked_context)243 spv_result_t MergeModules(const MessageConsumer& consumer,
244                           const std::vector<Module*>& input_modules,
245                           const AssemblyGrammar& grammar,
246                           IRContext* linked_context) {
247   spv_position_t position = {};
248 
249   if (linked_context == nullptr)
250     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
251            << "|linked_module| of MergeModules should not be null.";
252   Module* linked_module = linked_context->module();
253 
254   if (input_modules.empty()) return SPV_SUCCESS;
255 
256   for (const auto& module : input_modules)
257     for (const auto& inst : module->capabilities())
258       linked_module->AddCapability(
259           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
260 
261   for (const auto& module : input_modules)
262     for (const auto& inst : module->extensions())
263       linked_module->AddExtension(
264           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
265 
266   for (const auto& module : input_modules)
267     for (const auto& inst : module->ext_inst_imports())
268       linked_module->AddExtInstImport(
269           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
270 
271   const Instruction* linked_memory_model_inst =
272       input_modules.front()->GetMemoryModel();
273   if (linked_memory_model_inst == nullptr) {
274     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
275            << "Input module 1 is lacking an OpMemoryModel instruction.";
276   }
277   const uint32_t linked_addressing_model =
278       linked_memory_model_inst->GetSingleWordOperand(0u);
279   const uint32_t linked_memory_model =
280       linked_memory_model_inst->GetSingleWordOperand(1u);
281 
282   for (std::size_t i = 1; i < input_modules.size(); ++i) {
283     const Module* module = input_modules[i];
284     const Instruction* memory_model_inst = module->GetMemoryModel();
285     if (memory_model_inst == nullptr)
286       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
287              << "Input module " << (i + 1)
288              << " is lacking an OpMemoryModel instruction.";
289 
290     const uint32_t module_addressing_model =
291         memory_model_inst->GetSingleWordOperand(0u);
292     if (module_addressing_model != linked_addressing_model) {
293       spv_operand_desc linked_desc = nullptr, module_desc = nullptr;
294       grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
295                             linked_addressing_model, &linked_desc);
296       grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
297                             module_addressing_model, &module_desc);
298       return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
299              << "Conflicting addressing models: " << linked_desc->name
300              << " (input modules 1 through " << i << ") vs "
301              << module_desc->name << " (input module " << (i + 1) << ").";
302     }
303 
304     const uint32_t module_memory_model =
305         memory_model_inst->GetSingleWordOperand(1u);
306     if (module_memory_model != linked_memory_model) {
307       spv_operand_desc linked_desc = nullptr, module_desc = nullptr;
308       grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, linked_memory_model,
309                             &linked_desc);
310       grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, module_memory_model,
311                             &module_desc);
312       return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
313              << "Conflicting memory models: " << linked_desc->name
314              << " (input modules 1 through " << i << ") vs "
315              << module_desc->name << " (input module " << (i + 1) << ").";
316     }
317   }
318   linked_module->SetMemoryModel(std::unique_ptr<Instruction>(
319       linked_memory_model_inst->Clone(linked_context)));
320 
321   std::vector<std::pair<uint32_t, std::string>> entry_points;
322   for (const auto& module : input_modules)
323     for (const auto& inst : module->entry_points()) {
324       const uint32_t model = inst.GetSingleWordInOperand(0);
325       const std::string name = inst.GetInOperand(2).AsString();
326       const auto i = std::find_if(
327           entry_points.begin(), entry_points.end(),
328           [model, name](const std::pair<uint32_t, std::string>& v) {
329             return v.first == model && v.second == name;
330           });
331       if (i != entry_points.end()) {
332         spv_operand_desc desc = nullptr;
333         grammar.lookupOperand(SPV_OPERAND_TYPE_EXECUTION_MODEL, model, &desc);
334         return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
335                << "The entry point \"" << name << "\", with execution model "
336                << desc->name << ", was already defined.";
337       }
338       linked_module->AddEntryPoint(
339           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
340       entry_points.emplace_back(model, name);
341     }
342 
343   for (const auto& module : input_modules)
344     for (const auto& inst : module->execution_modes())
345       linked_module->AddExecutionMode(
346           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
347 
348   for (const auto& module : input_modules)
349     for (const auto& inst : module->debugs1())
350       linked_module->AddDebug1Inst(
351           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
352 
353   for (const auto& module : input_modules)
354     for (const auto& inst : module->debugs2())
355       linked_module->AddDebug2Inst(
356           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
357 
358   for (const auto& module : input_modules)
359     for (const auto& inst : module->debugs3())
360       linked_module->AddDebug3Inst(
361           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
362 
363   for (const auto& module : input_modules)
364     for (const auto& inst : module->ext_inst_debuginfo())
365       linked_module->AddExtInstDebugInfo(
366           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
367 
368   // If the generated module uses SPIR-V 1.1 or higher, add an
369   // OpModuleProcessed instruction about the linking step.
370   if (linked_module->version() >= SPV_SPIRV_VERSION_WORD(1, 1)) {
371     const std::string processed_string("Linked by SPIR-V Tools Linker");
372     std::vector<uint32_t> processed_words =
373         spvtools::utils::MakeVector(processed_string);
374     linked_module->AddDebug3Inst(std::unique_ptr<Instruction>(
375         new Instruction(linked_context, spv::Op::OpModuleProcessed, 0u, 0u,
376                         {{SPV_OPERAND_TYPE_LITERAL_STRING, processed_words}})));
377   }
378 
379   for (const auto& module : input_modules)
380     for (const auto& inst : module->annotations())
381       linked_module->AddAnnotationInst(
382           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
383 
384   // TODO(pierremoreau): Since the modules have not been validate, should we
385   //                     expect spv::StorageClass::Function variables outside
386   //                     functions?
387   for (const auto& module : input_modules) {
388     for (const auto& inst : module->types_values()) {
389       linked_module->AddType(
390           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
391     }
392   }
393 
394   // Process functions and their basic blocks
395   for (const auto& module : input_modules) {
396     for (const auto& func : *module) {
397       std::unique_ptr<opt::Function> cloned_func(func.Clone(linked_context));
398       linked_module->AddFunction(std::move(cloned_func));
399     }
400   }
401 
402   return SPV_SUCCESS;
403 }
404 
GetImportExportPairs(const MessageConsumer & consumer,const opt::IRContext & linked_context,const DefUseManager & def_use_manager,const DecorationManager & decoration_manager,bool allow_partial_linkage,LinkageTable * linkings_to_do)405 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
406                                   const opt::IRContext& linked_context,
407                                   const DefUseManager& def_use_manager,
408                                   const DecorationManager& decoration_manager,
409                                   bool allow_partial_linkage,
410                                   LinkageTable* linkings_to_do) {
411   spv_position_t position = {};
412 
413   if (linkings_to_do == nullptr)
414     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
415            << "|linkings_to_do| of GetImportExportPairs should not be empty.";
416 
417   std::vector<LinkageSymbolInfo> imports;
418   std::unordered_map<std::string, std::vector<LinkageSymbolInfo>> exports;
419 
420   // Figure out the imports and exports
421   for (const auto& decoration : linked_context.annotations()) {
422     if (decoration.opcode() != spv::Op::OpDecorate ||
423         spv::Decoration(decoration.GetSingleWordInOperand(1u)) !=
424             spv::Decoration::LinkageAttributes)
425       continue;
426 
427     const spv::Id id = decoration.GetSingleWordInOperand(0u);
428     // Ignore if the targeted symbol is a built-in
429     bool is_built_in = false;
430     for (const auto& id_decoration :
431          decoration_manager.GetDecorationsFor(id, false)) {
432       if (spv::Decoration(id_decoration->GetSingleWordInOperand(1u)) ==
433           spv::Decoration::BuiltIn) {
434         is_built_in = true;
435         break;
436       }
437     }
438     if (is_built_in) {
439       continue;
440     }
441 
442     const uint32_t type = decoration.GetSingleWordInOperand(3u);
443 
444     LinkageSymbolInfo symbol_info;
445     symbol_info.name = decoration.GetInOperand(2u).AsString();
446     symbol_info.id = id;
447     symbol_info.type_id = 0u;
448 
449     // Retrieve the type of the current symbol. This information will be used
450     // when checking that the imported and exported symbols have the same
451     // types.
452     const Instruction* def_inst = def_use_manager.GetDef(id);
453     if (def_inst == nullptr)
454       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
455              << "ID " << id << " is never defined:\n";
456 
457     if (def_inst->opcode() == spv::Op::OpVariable) {
458       symbol_info.type_id = def_inst->type_id();
459     } else if (def_inst->opcode() == spv::Op::OpFunction) {
460       symbol_info.type_id = def_inst->GetSingleWordInOperand(1u);
461 
462       // range-based for loop calls begin()/end(), but never cbegin()/cend(),
463       // which will not work here.
464       for (auto func_iter = linked_context.module()->cbegin();
465            func_iter != linked_context.module()->cend(); ++func_iter) {
466         if (func_iter->result_id() != id) continue;
467         func_iter->ForEachParam([&symbol_info](const Instruction* inst) {
468           symbol_info.parameter_ids.push_back(inst->result_id());
469         });
470       }
471     } else {
472       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
473              << "Only global variables and functions can be decorated using"
474              << " LinkageAttributes; " << id << " is neither of them.\n";
475     }
476 
477     if (spv::LinkageType(type) == spv::LinkageType::Import)
478       imports.push_back(symbol_info);
479     else if (spv::LinkageType(type) == spv::LinkageType::Export)
480       exports[symbol_info.name].push_back(symbol_info);
481   }
482 
483   // Find the import/export pairs
484   for (const auto& import : imports) {
485     std::vector<LinkageSymbolInfo> possible_exports;
486     const auto& exp = exports.find(import.name);
487     if (exp != exports.end()) possible_exports = exp->second;
488     if (possible_exports.empty() && !allow_partial_linkage)
489       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
490              << "Unresolved external reference to \"" << import.name << "\".";
491     else if (possible_exports.size() > 1u)
492       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
493              << "Too many external references, " << possible_exports.size()
494              << ", were found for \"" << import.name << "\".";
495 
496     if (!possible_exports.empty())
497       linkings_to_do->emplace_back(import, possible_exports.front());
498   }
499 
500   return SPV_SUCCESS;
501 }
502 
CheckImportExportCompatibility(const MessageConsumer & consumer,const LinkageTable & linkings_to_do,opt::IRContext * context)503 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
504                                             const LinkageTable& linkings_to_do,
505                                             opt::IRContext* context) {
506   spv_position_t position = {};
507 
508   // Ensure the import and export types are the same.
509   const DecorationManager& decoration_manager = *context->get_decoration_mgr();
510   const TypeManager& type_manager = *context->get_type_mgr();
511   for (const auto& linking_entry : linkings_to_do) {
512     Type* imported_symbol_type =
513         type_manager.GetType(linking_entry.imported_symbol.type_id);
514     Type* exported_symbol_type =
515         type_manager.GetType(linking_entry.exported_symbol.type_id);
516     if (!(*imported_symbol_type == *exported_symbol_type))
517       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
518              << "Type mismatch on symbol \""
519              << linking_entry.imported_symbol.name
520              << "\" between imported variable/function %"
521              << linking_entry.imported_symbol.id
522              << " and exported variable/function %"
523              << linking_entry.exported_symbol.id << ".";
524   }
525 
526   // Ensure the import and export decorations are similar
527   for (const auto& linking_entry : linkings_to_do) {
528     if (!decoration_manager.HaveTheSameDecorations(
529             linking_entry.imported_symbol.id, linking_entry.exported_symbol.id))
530       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
531              << "Decorations mismatch on symbol \""
532              << linking_entry.imported_symbol.name
533              << "\" between imported variable/function %"
534              << linking_entry.imported_symbol.id
535              << " and exported variable/function %"
536              << linking_entry.exported_symbol.id << ".";
537     // TODO(pierremoreau): Decorations on function parameters should probably
538     //                     match, except for FuncParamAttr if I understand the
539     //                     spec correctly.
540     // TODO(pierremoreau): Decorations on the function return type should
541     //                     match, except for FuncParamAttr.
542   }
543 
544   return SPV_SUCCESS;
545 }
546 
RemoveLinkageSpecificInstructions(const MessageConsumer & consumer,const LinkerOptions & options,const LinkageTable & linkings_to_do,DecorationManager * decoration_manager,opt::IRContext * linked_context)547 spv_result_t RemoveLinkageSpecificInstructions(
548     const MessageConsumer& consumer, const LinkerOptions& options,
549     const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
550     opt::IRContext* linked_context) {
551   spv_position_t position = {};
552 
553   if (decoration_manager == nullptr)
554     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
555            << "|decoration_manager| of RemoveLinkageSpecificInstructions "
556               "should not be empty.";
557   if (linked_context == nullptr)
558     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
559            << "|linked_module| of RemoveLinkageSpecificInstructions should not "
560               "be empty.";
561 
562   // TODO(pierremoreau): Remove FuncParamAttr decorations of imported
563   // functions' return type.
564 
565   // Remove prototypes of imported functions
566   for (const auto& linking_entry : linkings_to_do) {
567     for (auto func_iter = linked_context->module()->begin();
568          func_iter != linked_context->module()->end();) {
569       if (func_iter->result_id() == linking_entry.imported_symbol.id)
570         func_iter = func_iter.Erase();
571       else
572         ++func_iter;
573     }
574   }
575 
576   // Remove declarations of imported variables
577   for (const auto& linking_entry : linkings_to_do) {
578     auto next = linked_context->types_values_begin();
579     for (auto inst = next; inst != linked_context->types_values_end();
580          inst = next) {
581       ++next;
582       if (inst->result_id() == linking_entry.imported_symbol.id) {
583         linked_context->KillInst(&*inst);
584       }
585     }
586   }
587 
588   // If partial linkage is allowed, we need an efficient way to check whether
589   // an imported ID had a corresponding export symbol. As uses of the imported
590   // symbol have already been replaced by the exported symbol, use the exported
591   // symbol ID.
592   // TODO(pierremoreau): This will not work if the decoration is applied
593   //                     through a group, but the linker does not support that
594   //                     either.
595   std::unordered_set<spv::Id> imports;
596   if (options.GetAllowPartialLinkage()) {
597     imports.reserve(linkings_to_do.size());
598     for (const auto& linking_entry : linkings_to_do)
599       imports.emplace(linking_entry.exported_symbol.id);
600   }
601 
602   // Remove import linkage attributes
603   auto next = linked_context->annotation_begin();
604   for (auto inst = next; inst != linked_context->annotation_end();
605        inst = next) {
606     ++next;
607     // If this is an import annotation:
608     // * if we do not allow partial linkage, remove all import annotations;
609     // * otherwise, remove the annotation only if there was a corresponding
610     //   export.
611     if (inst->opcode() == spv::Op::OpDecorate &&
612         spv::Decoration(inst->GetSingleWordOperand(1u)) ==
613             spv::Decoration::LinkageAttributes &&
614         spv::LinkageType(inst->GetSingleWordOperand(3u)) ==
615             spv::LinkageType::Import &&
616         (!options.GetAllowPartialLinkage() ||
617          imports.find(inst->GetSingleWordOperand(0u)) != imports.end())) {
618       linked_context->KillInst(&*inst);
619     }
620   }
621 
622   // Remove export linkage attributes if making an executable
623   if (!options.GetCreateLibrary()) {
624     next = linked_context->annotation_begin();
625     for (auto inst = next; inst != linked_context->annotation_end();
626          inst = next) {
627       ++next;
628       if (inst->opcode() == spv::Op::OpDecorate &&
629           spv::Decoration(inst->GetSingleWordOperand(1u)) ==
630               spv::Decoration::LinkageAttributes &&
631           spv::LinkageType(inst->GetSingleWordOperand(3u)) ==
632               spv::LinkageType::Export) {
633         linked_context->KillInst(&*inst);
634       }
635     }
636   }
637 
638   // Remove Linkage capability if making an executable and partial linkage is
639   // not allowed
640   if (!options.GetCreateLibrary() && !options.GetAllowPartialLinkage()) {
641     for (auto& inst : linked_context->capabilities())
642       if (spv::Capability(inst.GetSingleWordInOperand(0u)) ==
643           spv::Capability::Linkage) {
644         linked_context->KillInst(&inst);
645         // The RemoveDuplicatesPass did remove duplicated capabilities, so we
646         // now there aren’t more spv::Capability::Linkage further down.
647         break;
648       }
649   }
650 
651   return SPV_SUCCESS;
652 }
653 
VerifyIds(const MessageConsumer & consumer,opt::IRContext * linked_context)654 spv_result_t VerifyIds(const MessageConsumer& consumer,
655                        opt::IRContext* linked_context) {
656   std::unordered_set<uint32_t> ids;
657   bool ok = true;
658   linked_context->module()->ForEachInst(
659       [&ids, &ok](const opt::Instruction* inst) {
660         ok &= ids.insert(inst->unique_id()).second;
661       });
662 
663   if (!ok) {
664     consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module");
665     return SPV_ERROR_INVALID_ID;
666   }
667 
668   return SPV_SUCCESS;
669 }
670 
VerifyLimits(const MessageConsumer & consumer,const opt::IRContext & linked_context)671 spv_result_t VerifyLimits(const MessageConsumer& consumer,
672                           const opt::IRContext& linked_context) {
673   spv_position_t position = {};
674 
675   const uint32_t max_id_bound = linked_context.module()->id_bound();
676   if (max_id_bound >= SPV_LIMIT_RESULT_ID_BOUND)
677     DiagnosticStream({0u, 0u, 4u}, consumer, "", SPV_WARNING)
678         << "The minimum limit of IDs, " << (SPV_LIMIT_RESULT_ID_BOUND - 1)
679         << ", was exceeded:"
680         << " " << max_id_bound << " is the current ID bound.\n"
681         << "The resulting module might not be supported by all "
682            "implementations.";
683 
684   size_t num_global_values = 0u;
685   for (const auto& inst : linked_context.module()->types_values()) {
686     num_global_values += inst.opcode() == spv::Op::OpVariable;
687   }
688   if (num_global_values >= SPV_LIMIT_GLOBAL_VARIABLES_MAX)
689     DiagnosticStream(position, consumer, "", SPV_WARNING)
690         << "The minimum limit of global values, "
691         << (SPV_LIMIT_GLOBAL_VARIABLES_MAX - 1) << ", was exceeded;"
692         << " " << num_global_values << " global values were found.\n"
693         << "The resulting module might not be supported by all "
694            "implementations.";
695 
696   return SPV_SUCCESS;
697 }
698 
699 }  // namespace
700 
Link(const Context & context,const std::vector<std::vector<uint32_t>> & binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)701 spv_result_t Link(const Context& context,
702                   const std::vector<std::vector<uint32_t>>& binaries,
703                   std::vector<uint32_t>* linked_binary,
704                   const LinkerOptions& options) {
705   std::vector<const uint32_t*> binary_ptrs;
706   binary_ptrs.reserve(binaries.size());
707   std::vector<size_t> binary_sizes;
708   binary_sizes.reserve(binaries.size());
709 
710   for (const auto& binary : binaries) {
711     binary_ptrs.push_back(binary.data());
712     binary_sizes.push_back(binary.size());
713   }
714 
715   return Link(context, binary_ptrs.data(), binary_sizes.data(), binaries.size(),
716               linked_binary, options);
717 }
718 
Link(const Context & context,const uint32_t * const * binaries,const size_t * binary_sizes,size_t num_binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)719 spv_result_t Link(const Context& context, const uint32_t* const* binaries,
720                   const size_t* binary_sizes, size_t num_binaries,
721                   std::vector<uint32_t>* linked_binary,
722                   const LinkerOptions& options) {
723   spv_position_t position = {};
724   const spv_context& c_context = context.CContext();
725   const MessageConsumer& consumer = c_context->consumer;
726 
727   linked_binary->clear();
728   if (num_binaries == 0u)
729     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
730            << "No modules were given.";
731 
732   std::vector<std::unique_ptr<IRContext>> ir_contexts;
733   std::vector<Module*> modules;
734   modules.reserve(num_binaries);
735   for (size_t i = 0u; i < num_binaries; ++i) {
736     const uint32_t schema = binaries[i][4u];
737     if (schema != 0u) {
738       position.index = 4u;
739       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
740              << "Schema is non-zero for module " << i + 1 << ".";
741     }
742 
743     std::unique_ptr<IRContext> ir_context = BuildModule(
744         c_context->target_env, consumer, binaries[i], binary_sizes[i]);
745     if (ir_context == nullptr)
746       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
747              << "Failed to build module " << i + 1 << " out of " << num_binaries
748              << ".";
749     modules.push_back(ir_context->module());
750     ir_contexts.push_back(std::move(ir_context));
751   }
752 
753   // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint
754   //          range from the other binaries, and compute the new ID bound.
755   uint32_t max_id_bound = 0u;
756   spv_result_t res = ShiftIdsInModules(consumer, &modules, &max_id_bound);
757   if (res != SPV_SUCCESS) return res;
758 
759   // Phase 2: Generate the header
760   opt::ModuleHeader header;
761   res = GenerateHeader(consumer, modules, max_id_bound, &header, options);
762   if (res != SPV_SUCCESS) return res;
763   IRContext linked_context(c_context->target_env, consumer);
764   linked_context.module()->SetHeader(header);
765 
766   // Phase 3: Merge all the binaries into a single one.
767   AssemblyGrammar grammar(c_context);
768   res = MergeModules(consumer, modules, grammar, &linked_context);
769   if (res != SPV_SUCCESS) return res;
770 
771   if (options.GetVerifyIds()) {
772     res = VerifyIds(consumer, &linked_context);
773     if (res != SPV_SUCCESS) return res;
774   }
775 
776   // Phase 4: Find the import/export pairs
777   LinkageTable linkings_to_do;
778   res = GetImportExportPairs(consumer, linked_context,
779                              *linked_context.get_def_use_mgr(),
780                              *linked_context.get_decoration_mgr(),
781                              options.GetAllowPartialLinkage(), &linkings_to_do);
782   if (res != SPV_SUCCESS) return res;
783 
784   // Phase 5: Ensure the import and export have the same types and decorations.
785   res =
786       CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context);
787   if (res != SPV_SUCCESS) return res;
788 
789   // Phase 6: Remove duplicates
790   PassManager manager;
791   manager.SetMessageConsumer(consumer);
792   manager.AddPass<RemoveDuplicatesPass>();
793   opt::Pass::Status pass_res = manager.Run(&linked_context);
794   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
795 
796   // Phase 7: Remove all names and decorations of import variables/functions
797   for (const auto& linking_entry : linkings_to_do) {
798     linked_context.KillNamesAndDecorates(linking_entry.imported_symbol.id);
799     for (const auto parameter_id :
800          linking_entry.imported_symbol.parameter_ids) {
801       linked_context.KillNamesAndDecorates(parameter_id);
802     }
803   }
804 
805   // Phase 8: Rematch import variables/functions to export variables/functions
806   for (const auto& linking_entry : linkings_to_do) {
807     linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id,
808                                       linking_entry.exported_symbol.id);
809   }
810 
811   // Phase 9: Remove linkage specific instructions, such as import/export
812   // attributes, linkage capability, etc. if applicable
813   res = RemoveLinkageSpecificInstructions(consumer, options, linkings_to_do,
814                                           linked_context.get_decoration_mgr(),
815                                           &linked_context);
816   if (res != SPV_SUCCESS) return res;
817 
818   // Phase 10: Compact the IDs used in the module
819   manager.AddPass<opt::CompactIdsPass>();
820   pass_res = manager.Run(&linked_context);
821   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
822 
823   // Phase 11: Recompute EntryPoint variables
824   manager.AddPass<opt::RemoveUnusedInterfaceVariablesPass>();
825   pass_res = manager.Run(&linked_context);
826   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
827 
828   // Phase 12: Warn if SPIR-V limits were exceeded
829   res = VerifyLimits(consumer, linked_context);
830   if (res != SPV_SUCCESS) return res;
831 
832   // Phase 13: Output the module
833   linked_context.module()->ToBinary(linked_binary, true);
834 
835   return SPV_SUCCESS;
836 }
837 
838 }  // namespace spvtools
839