• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2012-2016 Francisco Jerez
3 // Copyright 2012-2016 Advanced Micro Devices, Inc.
4 // Copyright 2014-2016 Jan Vesely
5 // Copyright 2014-2015 Serge Martin
6 // Copyright 2015 Zoltan Gilian
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a
9 // copy of this software and associated documentation files (the "Software"),
10 // to deal in the Software without restriction, including without limitation
11 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
12 // and/or sell copies of the Software, and to permit persons to whom the
13 // Software is furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24 // OTHER DEALINGS IN THE SOFTWARE.
25 
26 #include <cstdlib>
27 #include <filesystem>
28 #include <sstream>
29 #include <mutex>
30 
31 #include "util/ralloc.h"
32 #include "util/set.h"
33 #include <llvm/ADT/ArrayRef.h>
34 #include <llvm/IR/DiagnosticPrinter.h>
35 #include <llvm/IR/DiagnosticInfo.h>
36 #include <llvm/IR/LegacyPassManager.h>
37 #include <llvm/IR/LLVMContext.h>
38 #include <llvm/IR/Type.h>
39 #include <llvm/MC/TargetRegistry.h>
40 #include <llvm/Target/TargetMachine.h>
41 #include <llvm/Support/raw_ostream.h>
42 #include <llvm/Bitcode/BitcodeWriter.h>
43 #include <llvm/Bitcode/BitcodeReader.h>
44 #include <llvm-c/Core.h>
45 #include <llvm-c/Target.h>
46 #include <LLVMSPIRVLib/LLVMSPIRVLib.h>
47 
48 #include <clang/Config/config.h>
49 #include <clang/Driver/Driver.h>
50 #include <clang/CodeGen/CodeGenAction.h>
51 #include <clang/Lex/PreprocessorOptions.h>
52 #include <clang/Frontend/CompilerInstance.h>
53 #include <clang/Frontend/TextDiagnosticBuffer.h>
54 #include <clang/Frontend/TextDiagnosticPrinter.h>
55 #include <clang/Frontend/Utils.h>
56 #include <clang/Basic/TargetInfo.h>
57 
58 #include <spirv-tools/libspirv.hpp>
59 #include <spirv-tools/linker.hpp>
60 #include <spirv-tools/optimizer.hpp>
61 
62 #if LLVM_VERSION_MAJOR >= 20
63 #include <llvm/Support/VirtualFileSystem.h>
64 #endif
65 
66 #include "util/macros.h"
67 #include "glsl_types.h"
68 
69 #include "spirv.h"
70 
71 #if DETECT_OS_POSIX
72 #include <dlfcn.h>
73 #endif
74 
75 #ifdef USE_STATIC_OPENCL_C_H
76 #include "opencl-c-base.h.h"
77 #include "opencl-c.h.h"
78 #endif
79 
80 #include "clc_helpers.h"
81 
82 namespace fs = std::filesystem;
83 
84 /* Use the highest version of SPIRV supported by SPIRV-Tools. */
85 constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;
86 
87 constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
88 
89 using ::llvm::Function;
90 using ::llvm::legacy::PassManager;
91 using ::llvm::LLVMContext;
92 using ::llvm::Module;
93 using ::llvm::raw_string_ostream;
94 using ::llvm::TargetRegistry;
95 using ::clang::driver::Driver;
96 
97 static void
98 clc_dump_llvm(const llvm::Module *mod, FILE *f);
99 
100 static void
101 #if LLVM_VERSION_MAJOR >= 19
llvm_log_handler(const::llvm::DiagnosticInfo * di,void * data)102 llvm_log_handler(const ::llvm::DiagnosticInfo *di, void *data) {
103 #else
104 llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
105 #endif
106    const clc_logger *logger = static_cast<clc_logger *>(data);
107 
108    std::string log;
109    raw_string_ostream os { log };
110    ::llvm::DiagnosticPrinterRawOStream printer { os };
111 #if LLVM_VERSION_MAJOR >= 19
112    di->print(printer);
113 #else
114    di.print(printer);
115 #endif
116 
117    clc_error(logger, "%s", log.c_str());
118 }
119 
120 class SPIRVKernelArg {
121 public:
122    SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
123                                                   addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
124                                                   accessQualifier(0),
125                                                   typeQualifier(0) { }
126    ~SPIRVKernelArg() { }
127 
128    uint32_t id;
129    uint32_t typeId;
130    std::string name;
131    std::string typeName;
132    enum clc_kernel_arg_address_qualifier addrQualifier;
133    unsigned accessQualifier;
134    unsigned typeQualifier;
135 };
136 
137 class SPIRVKernelInfo {
138 public:
139    SPIRVKernelInfo(uint32_t fid, const char *nm)
140       : funcId(fid), name(nm), vecHint(0), localSize(), localSizeHint() { }
141    ~SPIRVKernelInfo() { }
142 
143    uint32_t funcId;
144    std::string name;
145    std::vector<SPIRVKernelArg> args;
146    unsigned vecHint;
147    unsigned localSize[3];
148    unsigned localSizeHint[3];
149 };
150 
151 class SPIRVKernelParser {
152 public:
153    SPIRVKernelParser() : curKernel(NULL)
154    {
155       ctx = spvContextCreate(spirv_target);
156    }
157 
158    ~SPIRVKernelParser()
159    {
160      spvContextDestroy(ctx);
161    }
162 
163    void parseEntryPoint(const spv_parsed_instruction_t *ins)
164    {
165       assert(ins->num_operands >= 3);
166 
167       const spv_parsed_operand_t *op = &ins->operands[1];
168 
169       assert(op->type == SPV_OPERAND_TYPE_ID);
170 
171       uint32_t funcId = ins->words[op->offset];
172 
173       for (auto &iter : kernels) {
174          if (funcId == iter.funcId)
175             return;
176       }
177 
178       op = &ins->operands[2];
179       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
180       const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
181 
182       kernels.push_back(SPIRVKernelInfo(funcId, name));
183    }
184 
185    void parseFunction(const spv_parsed_instruction_t *ins)
186    {
187       assert(ins->num_operands == 4);
188 
189       const spv_parsed_operand_t *op = &ins->operands[1];
190 
191       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
192 
193       uint32_t funcId = ins->words[op->offset];
194 
195       for (auto &kernel : kernels) {
196          if (funcId == kernel.funcId && !kernel.args.size()) {
197             curKernel = &kernel;
198 	    return;
199          }
200       }
201    }
202 
203    void parseFunctionParam(const spv_parsed_instruction_t *ins)
204    {
205       const spv_parsed_operand_t *op;
206       uint32_t id, typeId;
207 
208       if (!curKernel)
209          return;
210 
211       assert(ins->num_operands == 2);
212       op = &ins->operands[0];
213       assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
214       typeId = ins->words[op->offset];
215       op = &ins->operands[1];
216       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
217       id = ins->words[op->offset];
218       curKernel->args.push_back(SPIRVKernelArg(id, typeId));
219    }
220 
221    void parseName(const spv_parsed_instruction_t *ins)
222    {
223       const spv_parsed_operand_t *op;
224       const char *name;
225       uint32_t id;
226 
227       assert(ins->num_operands == 2);
228 
229       op = &ins->operands[0];
230       assert(op->type == SPV_OPERAND_TYPE_ID);
231       id = ins->words[op->offset];
232       op = &ins->operands[1];
233       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
234       name = reinterpret_cast<const char *>(ins->words + op->offset);
235 
236       for (auto &kernel : kernels) {
237          for (auto &arg : kernel.args) {
238             if (arg.id == id && arg.name.empty()) {
239               arg.name = name;
240               break;
241 	    }
242          }
243       }
244    }
245 
246    void parseTypePointer(const spv_parsed_instruction_t *ins)
247    {
248       enum clc_kernel_arg_address_qualifier addrQualifier;
249       uint32_t typeId, storageClass;
250       const spv_parsed_operand_t *op;
251 
252       assert(ins->num_operands == 3);
253 
254       op = &ins->operands[0];
255       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
256       typeId = ins->words[op->offset];
257 
258       op = &ins->operands[1];
259       assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
260       storageClass = ins->words[op->offset];
261       switch (storageClass) {
262       case SpvStorageClassCrossWorkgroup:
263          addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
264          break;
265       case SpvStorageClassWorkgroup:
266          addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
267          break;
268       case SpvStorageClassUniformConstant:
269          addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
270          break;
271       default:
272          addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
273          break;
274       }
275 
276       for (auto &kernel : kernels) {
277 	 for (auto &arg : kernel.args) {
278             if (arg.typeId == typeId) {
279                arg.addrQualifier = addrQualifier;
280                if (addrQualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT)
281                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
282             }
283          }
284       }
285    }
286 
287    void parseOpString(const spv_parsed_instruction_t *ins)
288    {
289       const spv_parsed_operand_t *op;
290       std::string str;
291 
292       assert(ins->num_operands == 2);
293 
294       op = &ins->operands[1];
295       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
296       str = reinterpret_cast<const char *>(ins->words + op->offset);
297 
298       size_t start = 0;
299       enum class string_type {
300          arg_type,
301          arg_type_qual,
302       } str_type;
303 
304       if (str.find("kernel_arg_type.") == 0) {
305          start = sizeof("kernel_arg_type.") - 1;
306          str_type = string_type::arg_type;
307       } else if (str.find("kernel_arg_type_qual.") == 0) {
308          start = sizeof("kernel_arg_type_qual.") - 1;
309          str_type = string_type::arg_type_qual;
310       } else {
311          return;
312       }
313 
314       for (auto &kernel : kernels) {
315          size_t pos;
316 
317 	 pos = str.find(kernel.name, start);
318          if (pos == std::string::npos ||
319              pos != start || str[start + kernel.name.size()] != '.')
320             continue;
321 
322 	 pos = start + kernel.name.size();
323          if (str[pos++] != '.')
324             continue;
325 
326          for (auto &arg : kernel.args) {
327             if (arg.name.empty())
328                break;
329 
330             size_t entryEnd = str.find(',', pos);
331 	    if (entryEnd == std::string::npos)
332                break;
333 
334             std::string entryVal = str.substr(pos, entryEnd - pos);
335             pos = entryEnd + 1;
336 
337             if (str_type == string_type::arg_type) {
338                arg.typeName = std::move(entryVal);
339             } else if (str_type == string_type::arg_type_qual) {
340                if (entryVal.find("const") != std::string::npos)
341                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
342             }
343          }
344       }
345    }
346 
347    void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
348    {
349       auto iter = decorationGroups.find(id);
350       if (iter != decorationGroups.end()) {
351          for (uint32_t entry : iter->second)
352             applyDecoration(entry, ins);
353          return;
354       }
355 
356       const spv_parsed_operand_t *op;
357       uint32_t decoration;
358 
359       assert(ins->num_operands >= 2);
360 
361       op = &ins->operands[1];
362       assert(op->type == SPV_OPERAND_TYPE_DECORATION);
363       decoration = ins->words[op->offset];
364 
365       if (decoration == SpvDecorationSpecId) {
366          uint32_t spec_id = ins->words[ins->operands[2].offset];
367          for (auto &c : specConstants) {
368             if (c.second.id == spec_id) {
369                return;
370             }
371          }
372          specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
373          return;
374       }
375 
376       for (auto &kernel : kernels) {
377          for (auto &arg : kernel.args) {
378             if (arg.id == id) {
379                switch (decoration) {
380                case SpvDecorationVolatile:
381                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
382                   break;
383                case SpvDecorationConstant:
384                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
385                   break;
386                case SpvDecorationRestrict:
387                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
388                   break;
389                case SpvDecorationFuncParamAttr:
390                   op = &ins->operands[2];
391                   assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
392                   switch (ins->words[op->offset]) {
393                   case SpvFunctionParameterAttributeNoAlias:
394                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
395                      break;
396                   case SpvFunctionParameterAttributeNoWrite:
397                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
398                      break;
399                   }
400                   break;
401                }
402             }
403 
404          }
405       }
406    }
407 
408    void parseOpDecorate(const spv_parsed_instruction_t *ins)
409    {
410       const spv_parsed_operand_t *op;
411       uint32_t id;
412 
413       assert(ins->num_operands >= 2);
414 
415       op = &ins->operands[0];
416       assert(op->type == SPV_OPERAND_TYPE_ID);
417       id = ins->words[op->offset];
418 
419       applyDecoration(id, ins);
420    }
421 
422    void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
423    {
424       assert(ins->num_operands >= 2);
425 
426       const spv_parsed_operand_t *op = &ins->operands[0];
427       assert(op->type == SPV_OPERAND_TYPE_ID);
428       uint32_t groupId = ins->words[op->offset];
429 
430       auto lowerBound = decorationGroups.lower_bound(groupId);
431       if (lowerBound != decorationGroups.end() &&
432           lowerBound->first == groupId)
433          // Group already filled out
434          return;
435 
436       auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
437       auto& vec = iter->second;
438       vec.reserve(ins->num_operands - 1);
439       for (uint32_t i = 1; i < ins->num_operands; ++i) {
440          op = &ins->operands[i];
441          assert(op->type == SPV_OPERAND_TYPE_ID);
442          vec.push_back(ins->words[op->offset]);
443       }
444    }
445 
446    void parseOpTypeImage(const spv_parsed_instruction_t *ins)
447    {
448       const spv_parsed_operand_t *op;
449       uint32_t typeId;
450       unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
451 
452       op = &ins->operands[0];
453       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
454       typeId = ins->words[op->offset];
455 
456       if (ins->num_operands >= 9) {
457          op = &ins->operands[8];
458          assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
459          switch (ins->words[op->offset]) {
460          case SpvAccessQualifierReadOnly:
461             accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
462             break;
463          case SpvAccessQualifierWriteOnly:
464             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
465             break;
466          case SpvAccessQualifierReadWrite:
467             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
468                CLC_KERNEL_ARG_ACCESS_READ;
469             break;
470          }
471       }
472 
473       for (auto &kernel : kernels) {
474 	 for (auto &arg : kernel.args) {
475             if (arg.typeId == typeId) {
476                arg.accessQualifier = accessQualifier;
477                arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
478             }
479          }
480       }
481    }
482 
483    void parseExecutionMode(const spv_parsed_instruction_t *ins)
484    {
485       uint32_t executionMode = ins->words[ins->operands[1].offset];
486       uint32_t funcId = ins->words[ins->operands[0].offset];
487 
488       for (auto& kernel : kernels) {
489          if (kernel.funcId == funcId) {
490             switch (executionMode) {
491             case SpvExecutionModeVecTypeHint:
492                kernel.vecHint = ins->words[ins->operands[2].offset];
493                break;
494             case SpvExecutionModeLocalSize:
495                kernel.localSize[0] = ins->words[ins->operands[2].offset];
496                kernel.localSize[1] = ins->words[ins->operands[3].offset];
497                kernel.localSize[2] = ins->words[ins->operands[4].offset];
498                break;
499             case SpvExecutionModeLocalSizeHint:
500                kernel.localSizeHint[0] = ins->words[ins->operands[2].offset];
501                kernel.localSizeHint[1] = ins->words[ins->operands[3].offset];
502                kernel.localSizeHint[2] = ins->words[ins->operands[4].offset];
503                break;
504             default:
505                return;
506             }
507          }
508       }
509    }
510 
511    void parseLiteralType(const spv_parsed_instruction_t *ins)
512    {
513       uint32_t typeId = ins->words[ins->operands[0].offset];
514       auto& literalType = literalTypes[typeId];
515       switch (ins->opcode) {
516       case SpvOpTypeBool:
517          literalType = CLC_SPEC_CONSTANT_BOOL;
518          break;
519       case SpvOpTypeFloat: {
520          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
521          switch (sizeInBits) {
522          case 32:
523             literalType = CLC_SPEC_CONSTANT_FLOAT;
524             break;
525          case 64:
526             literalType = CLC_SPEC_CONSTANT_DOUBLE;
527             break;
528          case 16:
529             /* Can't be used for a spec constant */
530             break;
531          default:
532             unreachable("Unexpected float bit size");
533          }
534          break;
535       }
536       case SpvOpTypeInt: {
537          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
538          bool isSigned = ins->words[ins->operands[2].offset];
539          if (isSigned) {
540             switch (sizeInBits) {
541             case 8:
542                literalType = CLC_SPEC_CONSTANT_INT8;
543                break;
544             case 16:
545                literalType = CLC_SPEC_CONSTANT_INT16;
546                break;
547             case 32:
548                literalType = CLC_SPEC_CONSTANT_INT32;
549                break;
550             case 64:
551                literalType = CLC_SPEC_CONSTANT_INT64;
552                break;
553             default:
554                unreachable("Unexpected int bit size");
555             }
556          } else {
557             switch (sizeInBits) {
558             case 8:
559                literalType = CLC_SPEC_CONSTANT_UINT8;
560                break;
561             case 16:
562                literalType = CLC_SPEC_CONSTANT_UINT16;
563                break;
564             case 32:
565                literalType = CLC_SPEC_CONSTANT_UINT32;
566                break;
567             case 64:
568                literalType = CLC_SPEC_CONSTANT_UINT64;
569                break;
570             default:
571                unreachable("Unexpected uint bit size");
572             }
573          }
574          break;
575       }
576       default:
577          unreachable("Unexpected type opcode");
578       }
579    }
580 
581    void parseSpecConstant(const spv_parsed_instruction_t *ins)
582    {
583       uint32_t id = ins->result_id;
584       for (auto& c : specConstants) {
585          if (c.first == id) {
586             auto& data = c.second;
587             switch (ins->opcode) {
588             case SpvOpSpecConstant: {
589                uint32_t typeId = ins->words[ins->operands[0].offset];
590 
591                // This better be an integer or float type
592                auto typeIter = literalTypes.find(typeId);
593                assert(typeIter != literalTypes.end());
594 
595                data.type = typeIter->second;
596                break;
597             }
598             case SpvOpSpecConstantFalse:
599             case SpvOpSpecConstantTrue:
600                data.type = CLC_SPEC_CONSTANT_BOOL;
601                break;
602             default:
603                unreachable("Composites and Ops are not directly specializable.");
604             }
605          }
606       }
607    }
608 
609    static spv_result_t
610    parseInstruction(void *data, const spv_parsed_instruction_t *ins)
611    {
612       SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
613 
614       switch (ins->opcode) {
615       case SpvOpName:
616          parser->parseName(ins);
617          break;
618       case SpvOpEntryPoint:
619          parser->parseEntryPoint(ins);
620          break;
621       case SpvOpFunction:
622          parser->parseFunction(ins);
623          break;
624       case SpvOpFunctionParameter:
625          parser->parseFunctionParam(ins);
626          break;
627       case SpvOpFunctionEnd:
628       case SpvOpLabel:
629          parser->curKernel = NULL;
630          break;
631       case SpvOpTypePointer:
632          parser->parseTypePointer(ins);
633          break;
634       case SpvOpTypeImage:
635          parser->parseOpTypeImage(ins);
636          break;
637       case SpvOpString:
638          parser->parseOpString(ins);
639          break;
640       case SpvOpDecorate:
641          parser->parseOpDecorate(ins);
642          break;
643       case SpvOpGroupDecorate:
644          parser->parseOpGroupDecorate(ins);
645          break;
646       case SpvOpExecutionMode:
647          parser->parseExecutionMode(ins);
648          break;
649       case SpvOpTypeBool:
650       case SpvOpTypeInt:
651       case SpvOpTypeFloat:
652          parser->parseLiteralType(ins);
653          break;
654       case SpvOpSpecConstant:
655       case SpvOpSpecConstantFalse:
656       case SpvOpSpecConstantTrue:
657          parser->parseSpecConstant(ins);
658          break;
659       default:
660          break;
661       }
662 
663       return SPV_SUCCESS;
664    }
665 
666    bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
667    {
668       /* 3 passes should be enough to retrieve all kernel information:
669        * 1st pass: all entry point name and number of args
670        * 2nd pass: argument names and type names
671        * 3rd pass: pointer type names
672        */
673       for (unsigned pass = 0; pass < 3; pass++) {
674          spv_diagnostic diagnostic = NULL;
675          auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
676                                       static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
677                                       NULL, parseInstruction, &diagnostic);
678 
679          if (result != SPV_SUCCESS) {
680             if (diagnostic && logger)
681                logger->error(logger->priv, diagnostic->error);
682             return false;
683          }
684       }
685 
686       return true;
687    }
688 
689    std::vector<SPIRVKernelInfo> kernels;
690    std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
691    std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
692    std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
693    SPIRVKernelInfo *curKernel;
694    spv_context ctx;
695 };
696 
697 bool
698 clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
699                            const struct clc_kernel_info **out_kernels,
700                            unsigned *num_kernels,
701                            const struct clc_parsed_spec_constant **out_spec_constants,
702                            unsigned *num_spec_constants,
703                            const struct clc_logger *logger)
704 {
705    struct clc_kernel_info *kernels = NULL;
706    struct clc_parsed_spec_constant *spec_constants = NULL;
707 
708    SPIRVKernelParser parser;
709 
710    if (!parser.parseBinary(*spvbin, logger))
711       return false;
712 
713    *num_kernels = parser.kernels.size();
714    *num_spec_constants = parser.specConstants.size();
715    if (*num_kernels) {
716       kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
717                                                                   sizeof(*kernels)));
718       assert(kernels);
719       for (unsigned i = 0; i < parser.kernels.size(); i++) {
720          kernels[i].name = strdup(parser.kernels[i].name.c_str());
721          kernels[i].num_args = parser.kernels[i].args.size();
722          kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
723          kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
724          memcpy(kernels[i].local_size, parser.kernels[i].localSize, sizeof(kernels[i].local_size));
725          memcpy(kernels[i].local_size_hint, parser.kernels[i].localSizeHint, sizeof(kernels[i].local_size_hint));
726          if (!kernels[i].num_args)
727             continue;
728 
729          struct clc_kernel_arg *args;
730 
731          args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
732                                                                  sizeof(*kernels->args)));
733          kernels[i].args = args;
734          assert(args);
735          for (unsigned j = 0; j < kernels[i].num_args; j++) {
736             if (!parser.kernels[i].args[j].name.empty())
737                args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
738             args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
739             args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
740             args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
741             args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
742          }
743       }
744    }
745 
746    if (*num_spec_constants) {
747       spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
748                                                                                   sizeof(*spec_constants)));
749       assert(spec_constants);
750 
751       for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
752          spec_constants[i] = parser.specConstants[i].second;
753       }
754    }
755 
756    *out_kernels = kernels;
757    *out_spec_constants = spec_constants;
758 
759    return true;
760 }
761 
762 void
763 clc_free_kernels_info(const struct clc_kernel_info *kernels,
764                       unsigned num_kernels)
765 {
766    if (!kernels)
767       return;
768 
769    for (unsigned i = 0; i < num_kernels; i++) {
770       if (kernels[i].args) {
771          for (unsigned j = 0; j < kernels[i].num_args; j++) {
772             free((void *)kernels[i].args[j].name);
773             free((void *)kernels[i].args[j].type_name);
774          }
775          free((void *)kernels[i].args);
776       }
777       free((void *)kernels[i].name);
778    }
779 
780    free((void *)kernels);
781 }
782 
783 static std::unique_ptr<::llvm::Module>
784 clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
785                            const struct clc_compile_args *args,
786                            const struct clc_logger *logger,
787                            struct set *dependencies)
788 {
789    static_assert(std::has_unique_object_representations<clc_optional_features>(),
790                  "no padding allowed inside clc_optional_features");
791 
792    std::string diag_log_str;
793    raw_string_ostream diag_log_stream { diag_log_str };
794 
795    std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
796    std::shared_ptr<clang::DependencyCollector> dep;
797    if (dependencies != nullptr) {
798       dep = std::make_shared<clang::DependencyCollector>();
799       c->addDependencyCollector(dep);
800    }
801 
802    clang::DiagnosticsEngine diag {
803       new clang::DiagnosticIDs,
804       new clang::DiagnosticOptions,
805       new clang::TextDiagnosticPrinter(diag_log_stream,
806                                        &c->getDiagnosticOpts())
807    };
808 
809 #if LLVM_VERSION_MAJOR >= 17
810    const char *triple = args->address_bits == 32 ? "spir-unknown-unknown" : "spirv64-unknown-unknown";
811 #else
812    const char *triple = args->address_bits == 32 ? "spir-unknown-unknown" : "spir64-unknown-unknown";
813 #endif
814 
815    std::vector<const char *> clang_opts = {
816       args->source.name,
817       "-triple", triple,
818       // By default, clang prefers to use modules to pull in the default headers,
819       // which doesn't work with our technique of embedding the headers in our binary
820       "-fdeclare-opencl-builtins",
821 #if LLVM_VERSION_MAJOR < 17
822       "-no-opaque-pointers",
823 #endif
824       // Add a default CL compiler version. Clang will pick the last one specified
825       // on the command line, so the app can override this one.
826       "-cl-std=cl1.2",
827       // The LLVM-SPIRV-Translator doesn't support memset with variable size
828       "-fno-builtin-memset",
829       // LLVM's optimizations can produce code that the translator can't translate
830       "-O0",
831       // Ensure inline functions are actually emitted
832       "-fgnu89-inline",
833    };
834 
835    // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
836    // being provided by the caller here.
837    clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);
838 
839    if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
840                                                   clang_opts,
841                                                   diag)) {
842       clc_error(logger, "Couldn't create Clang invocation.\n");
843       return {};
844    }
845 
846    if (diag.hasErrorOccurred()) {
847       clc_error(logger, "%sErrors occurred during Clang invocation.\n",
848                 diag_log_str.c_str());
849       return {};
850    }
851 
852    // This is a workaround for a Clang bug which causes the number
853    // of warnings and errors to be printed to stderr.
854    // http://www.llvm.org/bugs/show_bug.cgi?id=19735
855    c->getDiagnosticOpts().ShowCarets = false;
856 
857    c->createDiagnostics(
858 #if LLVM_VERSION_MAJOR >= 20
859                    *llvm::vfs::getRealFileSystem(),
860 #endif
861                    new clang::TextDiagnosticPrinter(
862                            diag_log_stream,
863                            &c->getDiagnosticOpts()));
864 
865    c->setTarget(clang::TargetInfo::CreateTargetInfo(
866                    c->getDiagnostics(), c->getInvocation().TargetOpts));
867 
868    c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;
869 
870 #ifdef USE_STATIC_OPENCL_C_H
871    c->getHeaderSearchOpts().UseBuiltinIncludes = false;
872    c->getHeaderSearchOpts().UseStandardSystemIncludes = false;
873 
874    // Add opencl-c generic search path
875    {
876       ::llvm::SmallString<128> system_header_path;
877       ::llvm::sys::path::system_temp_directory(true, system_header_path);
878       ::llvm::sys::path::append(system_header_path, "openclon12");
879       c->getHeaderSearchOpts().AddPath(system_header_path.str(),
880                                        clang::frontend::Angled,
881                                        false, false);
882 
883       ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
884       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
885          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
886       // this line is actually important to make it include `opencl-c.h`
887       ::llvm::sys::path::remove_filename(system_header_path);
888       ::llvm::sys::path::append(system_header_path, "opencl-c.h");
889       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
890          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());
891    }
892 #else
893 
894    Dl_info info;
895    if (dladdr((void *)clang::CompilerInvocation::CreateFromArgs, &info) == 0) {
896       clc_error(logger, "Couldn't find libclang path.\n");
897       return {};
898    }
899 
900    char *clang_path = realpath(info.dli_fname, NULL);
901    if (clang_path == nullptr) {
902       clc_error(logger, "Couldn't find libclang path.\n");
903       return {};
904    }
905 
906    // GetResourcePath is a way to retrieve the actual libclang resource dir based on a given binary
907    // or library.
908    auto tmp_res_path =
909 #if LLVM_VERSION_MAJOR >= 20
910       Driver::GetResourcesPath(std::string(clang_path));
911 #else
912       Driver::GetResourcesPath(std::string(clang_path), CLANG_RESOURCE_DIR);
913 #endif
914    auto clang_res_path = fs::path(tmp_res_path) / "include";
915 
916    free(clang_path);
917 
918    c->getHeaderSearchOpts().UseBuiltinIncludes = true;
919    c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
920    c->getHeaderSearchOpts().ResourceDir = clang_res_path.string();
921 
922    // Add opencl-c generic search path
923    c->getHeaderSearchOpts().AddPath(clang_res_path.string(),
924                                     clang::frontend::Angled,
925                                     false, false);
926 
927    auto clang_install_res_path =
928       fs::path(LLVM_LIB_DIR) / "clang" / std::to_string(LLVM_VERSION_MAJOR) / "include";
929    c->getHeaderSearchOpts().AddPath(clang_install_res_path.string(),
930                                     clang::frontend::Angled,
931                                     false, false);
932 #endif
933 
934    // Enable/Disable optional OpenCL C features. Some can be toggled via `OpenCLExtensionsAsWritten`
935    // others we have to (un)define via macros ourselves.
936 
937    // Undefine clang added SPIR(V) defines so we don't magically enable extensions
938    c->getPreprocessorOpts().addMacroUndef("__SPIR__");
939    c->getPreprocessorOpts().addMacroUndef("__SPIRV__");
940 
941    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("-all");
942    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_byte_addressable_store");
943    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_base_atomics");
944    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_extended_atomics");
945    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_base_atomics");
946    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_extended_atomics");
947    c->getPreprocessorOpts().addMacroDef("cl_khr_expect_assume=1");
948 
949    bool needs_opencl_c_h = false;
950    if (args->features.fp16) {
951       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp16");
952    }
953    if (args->features.fp64) {
954       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp64");
955       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_fp64");
956    }
957    if (args->features.int64) {
958       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cles_khr_int64");
959       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_int64");
960    } else {
961       // clang defines this unconditionally, we need to fix that.
962       c->getPreprocessorOpts().addMacroUndef("__opencl_c_int64");
963    }
964    if (args->features.images) {
965       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_images");
966    } else {
967       // clang defines this unconditionally, we need to fix that.
968       c->getPreprocessorOpts().addMacroUndef("__IMAGE_SUPPORT__");
969    }
970    if (args->features.images_read_write) {
971       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_read_write_images");
972    }
973    if (args->features.images_write_3d) {
974       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_3d_image_writes");
975       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_3d_image_writes");
976    }
977    if (args->features.images_depth) {
978       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_depth_images");
979    }
980    if (args->features.images_gl_depth) {
981       c->getPreprocessorOpts().addMacroDef("cl_khr_gl_depth_images=1");
982    }
983    if (args->features.images_mipmap) {
984       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_mipmap_image");
985    }
986    if (args->features.images_mipmap_writes) {
987       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_mipmap_image_writes");
988    }
989    if (args->features.images_gl_msaa) {
990       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_gl_msaa_sharing");
991    }
992    if (args->features.intel_subgroups) {
993       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_intel_subgroups");
994       needs_opencl_c_h = true;
995    }
996    if (args->features.subgroups) {
997       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_subgroups");
998       if (args->features.subgroups_shuffle) {
999          c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle=1");
1000       }
1001       if (args->features.subgroups_shuffle_relative) {
1002          c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle_relative=1");
1003       }
1004       if (args->features.subgroups_ballot) {
1005          c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_ballot=1");
1006       }
1007    }
1008    if (args->features.subgroups_ifp) {
1009       assert(args->features.subgroups);
1010       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_subgroups");
1011    }
1012    if (args->features.integer_dot_product) {
1013       c->getPreprocessorOpts().addMacroDef("cl_khr_integer_dot_product=1");
1014       c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit_packed=1");
1015       c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit=1");
1016    }
1017 
1018    // Add opencl include
1019    c->getPreprocessorOpts().Includes.push_back("opencl-c-base.h");
1020    if (needs_opencl_c_h) {
1021       c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
1022    }
1023 
1024    if (args->num_headers) {
1025       ::llvm::SmallString<128> tmp_header_path;
1026       ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
1027       ::llvm::sys::path::append(tmp_header_path, "openclon12");
1028 
1029       c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
1030                                        clang::frontend::Quoted,
1031                                        false, false);
1032 
1033       for (size_t i = 0; i < args->num_headers; i++) {
1034          auto path_copy = tmp_header_path;
1035          ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
1036          c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
1037             ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
1038       }
1039    }
1040 
1041    c->getPreprocessorOpts().addRemappedFile(
1042            args->source.name,
1043            ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
1044 
1045    // Compile the code
1046    clang::EmitLLVMOnlyAction act(&llvm_ctx);
1047    if (!c->ExecuteAction(act)) {
1048       clc_error(logger, "%sError executing LLVM compilation action.\n",
1049                 diag_log_str.c_str());
1050       return {};
1051    }
1052 
1053    if (dependencies != nullptr) {
1054       for (auto dep : dep->getDependencies()) {
1055          _mesa_set_add(dependencies, ralloc_strdup(dependencies, dep.c_str()));
1056       }
1057    }
1058 
1059    auto mod = act.takeModule();
1060 
1061    if (clc_debug_flags() & CLC_DEBUG_DUMP_LLVM)
1062       clc_dump_llvm(mod.get(), stdout);
1063 
1064    return mod;
1065 }
1066 
1067 static SPIRV::VersionNumber
1068 spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
1069 {
1070    switch (version) {
1071    case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
1072    case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
1073    case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
1074    case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
1075    case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
1076    case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
1077    default:      return invalid_spirv_trans_version;
1078    }
1079 }
1080 
1081 static int
1082 llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
1083                   LLVMContext &context,
1084                   const struct clc_compile_args *args,
1085                   const struct clc_logger *logger,
1086                   struct clc_binary *out_spirv)
1087 {
1088    std::string log;
1089 
1090    SPIRV::VersionNumber version =
1091       spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
1092    if (version == invalid_spirv_trans_version) {
1093       clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
1094       return -1;
1095    }
1096 
1097    const char *const *extensions = args->allowed_spirv_extensions;
1098    if (!extensions) {
1099       /* The SPIR-V parser doesn't handle all extensions */
1100       static const char *default_extensions[] = {
1101          "SPV_EXT_shader_atomic_float_add",
1102          "SPV_EXT_shader_atomic_float_min_max",
1103          "SPV_KHR_float_controls",
1104          NULL,
1105       };
1106       extensions = default_extensions;
1107    }
1108 
1109    SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
1110    for (int i = 0; extensions[i]; i++) {
1111 #define EXT(X) \
1112       if (strcmp(#X, extensions[i]) == 0) \
1113          ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
1114 #include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
1115 #undef EXT
1116    }
1117    SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
1118 
1119    /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
1120    spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
1121 
1122 #if LLVM_VERSION_MAJOR >= 17
1123    if (args->use_llvm_spirv_target) {
1124       const char *triple = args->address_bits == 32 ? "spirv-unknown-unknown" : "spirv64-unknown-unknown";
1125       std::string error_msg("");
1126       auto target = TargetRegistry::lookupTarget(triple, error_msg);
1127       if (target) {
1128          auto TM = target->createTargetMachine(
1129             triple, "", "", {}, std::nullopt, std::nullopt,
1130 #if LLVM_VERSION_MAJOR >= 18
1131             ::llvm::CodeGenOptLevel::None
1132 #else
1133             ::llvm::CodeGenOpt::None
1134 #endif
1135          );
1136 
1137          auto PM = PassManager();
1138          ::llvm::SmallVector<char> buf;
1139          auto OS = ::llvm::raw_svector_ostream(buf);
1140          TM->addPassesToEmitFile(
1141             PM, OS, nullptr,
1142 #if LLVM_VERSION_MAJOR >= 18
1143             ::llvm::CodeGenFileType::ObjectFile
1144 #else
1145             ::llvm::CGFT_ObjectFile
1146 #endif
1147          );
1148 
1149          PM.run(*mod);
1150 
1151          out_spirv->size = buf.size_in_bytes();
1152          out_spirv->data = malloc(out_spirv->size);
1153          memcpy(out_spirv->data, buf.data(), out_spirv->size);
1154          return 0;
1155       } else {
1156          clc_error(logger, "LLVM SPIR-V target not found.\n");
1157          return -1;
1158       }
1159    }
1160 #endif
1161 
1162    std::ostringstream spv_stream;
1163    if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
1164       clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
1165                 log.c_str());
1166       return -1;
1167    }
1168 
1169    const std::string spv_out = spv_stream.str();
1170    out_spirv->size = spv_out.size();
1171    out_spirv->data = malloc(out_spirv->size);
1172    memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
1173 
1174    return 0;
1175 }
1176 
1177 int
1178 clc_c_to_spir(const struct clc_compile_args *args,
1179               const struct clc_logger *logger,
1180               struct clc_binary *out_spir,
1181               struct set *dependencies)
1182 {
1183    clc_initialize_llvm();
1184 
1185    LLVMContext llvm_ctx;
1186    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1187                                          const_cast<clc_logger *>(logger));
1188 
1189    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger, dependencies);
1190    if (!mod)
1191       return -1;
1192 
1193    ::llvm::SmallVector<char, 0> buffer;
1194    ::llvm::BitcodeWriter writer(buffer);
1195    writer.writeModule(*mod);
1196 
1197    out_spir->size = buffer.size_in_bytes();
1198    out_spir->data = malloc(out_spir->size);
1199    memcpy(out_spir->data, buffer.data(), out_spir->size);
1200 
1201    return 0;
1202 }
1203 
1204 int
1205 clc_c_to_spirv(const struct clc_compile_args *args,
1206                const struct clc_logger *logger,
1207                struct clc_binary *out_spirv,
1208                struct set *dependencies)
1209 {
1210    clc_initialize_llvm();
1211 
1212    LLVMContext llvm_ctx;
1213    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1214                                          const_cast<clc_logger *>(logger));
1215 
1216    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger, dependencies);
1217    if (!mod)
1218       return -1;
1219 
1220    return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
1221 }
1222 
1223 int
1224 clc_spir_to_spirv(const struct clc_binary *in_spir,
1225                   const struct clc_logger *logger,
1226                   struct clc_binary *out_spirv)
1227 {
1228    clc_initialize_llvm();
1229 
1230    LLVMContext llvm_ctx;
1231    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1232                                          const_cast<clc_logger *>(logger));
1233 
1234    ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
1235    auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
1236    if (!mod)
1237       return -1;
1238 
1239    return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
1240 }
1241 
1242 class SPIRVMessageConsumer {
1243 public:
1244    SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
1245 
1246    void operator()(spv_message_level_t level, const char *src,
1247                    const spv_position_t &pos, const char *msg)
1248    {
1249       if (level == SPV_MSG_INFO || level == SPV_MSG_DEBUG)
1250          return;
1251 
1252       std::ostringstream message;
1253       message << "(file=" << (src ? src : "input")
1254               << ",line=" << pos.line
1255               << ",column=" << pos.column
1256               << ",index=" << pos.index
1257               << "): " << msg << "\n";
1258 
1259       if (level == SPV_MSG_WARNING)
1260          clc_warning(logger, "%s", message.str().c_str());
1261       else
1262          clc_error(logger, "%s", message.str().c_str());
1263    }
1264 
1265 private:
1266    const struct clc_logger *logger;
1267 };
1268 
1269 int
1270 clc_link_spirv_binaries(const struct clc_linker_args *args,
1271                         const struct clc_logger *logger,
1272                         struct clc_binary *out_spirv)
1273 {
1274    std::vector<std::vector<uint32_t>> binaries;
1275 
1276    for (unsigned i = 0; i < args->num_in_objs; i++) {
1277       const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1278       std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1279       binaries.push_back(bin);
1280    }
1281 
1282    SPIRVMessageConsumer msgconsumer(logger);
1283    spvtools::Context context(spirv_target);
1284    context.SetMessageConsumer(msgconsumer);
1285    spvtools::LinkerOptions options;
1286    options.SetAllowPartialLinkage(args->create_library);
1287    #if defined(HAS_SPIRV_LINK_LLVM_WORKAROUND) && LLVM_VERSION_MAJOR >= 17
1288       options.SetAllowPtrTypeMismatch(true);
1289    #endif
1290    options.SetCreateLibrary(args->create_library);
1291    std::vector<uint32_t> linkingResult;
1292    spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1293    if (status != SPV_SUCCESS) {
1294       #if !defined(HAS_SPIRV_LINK_LLVM_WORKAROUND) && LLVM_VERSION_MAJOR >= 17
1295         clc_warning(logger, "SPIRV-Tools doesn't contain https://github.com/KhronosGroup/SPIRV-Tools/pull/5534\n");
1296         clc_warning(logger, "Please update in order to prevent spurious linking failures\n");
1297       #endif
1298       return -1;
1299    }
1300 
1301    out_spirv->size = linkingResult.size() * 4;
1302    out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1303    memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1304 
1305    return 0;
1306 }
1307 
1308 bool
1309 clc_validate_spirv(const struct clc_binary *spirv,
1310                    const struct clc_logger *logger,
1311                    const struct clc_validator_options *options)
1312 {
1313    SPIRVMessageConsumer msgconsumer(logger);
1314    spvtools::SpirvTools tools(spirv_target);
1315    tools.SetMessageConsumer(msgconsumer);
1316    spvtools::ValidatorOptions spirv_options;
1317    const uint32_t *data = static_cast<const uint32_t *>(spirv->data);
1318 
1319    if (options) {
1320       spirv_options.SetUniversalLimit(
1321          spv_validator_limit_max_function_args,
1322          options->limit_max_function_arg);
1323    }
1324 
1325    return tools.Validate(data, spirv->size / 4, spirv_options);
1326 }
1327 
1328 int
1329 clc_spirv_specialize(const struct clc_binary *in_spirv,
1330                      const struct clc_parsed_spirv *parsed_data,
1331                      const struct clc_spirv_specialization_consts *consts,
1332                      struct clc_binary *out_spirv)
1333 {
1334    std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1335    for (unsigned i = 0; i < consts->num_specializations; ++i) {
1336       unsigned id = consts->specializations[i].id;
1337       auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1338          parsed_data->spec_constants + parsed_data->num_spec_constants,
1339          [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1340       assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1341 
1342       std::vector<uint32_t> words;
1343       switch (parsed_spec_const->type) {
1344       case CLC_SPEC_CONSTANT_BOOL:
1345          words.push_back(consts->specializations[i].value.b);
1346          break;
1347       case CLC_SPEC_CONSTANT_INT32:
1348       case CLC_SPEC_CONSTANT_UINT32:
1349       case CLC_SPEC_CONSTANT_FLOAT:
1350          words.push_back(consts->specializations[i].value.u32);
1351          break;
1352       case CLC_SPEC_CONSTANT_INT16:
1353          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1354          break;
1355       case CLC_SPEC_CONSTANT_INT8:
1356          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1357          break;
1358       case CLC_SPEC_CONSTANT_UINT16:
1359          words.push_back((uint32_t)consts->specializations[i].value.u16);
1360          break;
1361       case CLC_SPEC_CONSTANT_UINT8:
1362          words.push_back((uint32_t)consts->specializations[i].value.u8);
1363          break;
1364       case CLC_SPEC_CONSTANT_DOUBLE:
1365       case CLC_SPEC_CONSTANT_INT64:
1366       case CLC_SPEC_CONSTANT_UINT64:
1367          words.resize(2);
1368          memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1369          break;
1370       case CLC_SPEC_CONSTANT_UNKNOWN:
1371          assert(0);
1372          break;
1373       }
1374 
1375       ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1376       assert(ret.second);
1377    }
1378 
1379    spvtools::Optimizer opt(spirv_target);
1380    opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1381 
1382    std::vector<uint32_t> result;
1383    if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1384       return false;
1385 
1386    out_spirv->size = result.size() * 4;
1387    out_spirv->data = malloc(out_spirv->size);
1388    memcpy(out_spirv->data, result.data(), out_spirv->size);
1389    return true;
1390 }
1391 
1392 static void
1393 clc_dump_llvm(const llvm::Module *mod, FILE *f)
1394 {
1395    std::string out;
1396    raw_string_ostream os(out);
1397 
1398    mod->print(os, nullptr);
1399    os.flush();
1400 
1401    fwrite(out.c_str(), out.size(), 1, f);
1402 }
1403 
1404 void
1405 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1406 {
1407    spvtools::SpirvTools tools(spirv_target);
1408    const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1409    std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1410    std::string out;
1411    tools.Disassemble(bin, &out,
1412                      SPV_BINARY_TO_TEXT_OPTION_INDENT |
1413                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1414    fwrite(out.c_str(), out.size(), 1, f);
1415 }
1416 
1417 void
1418 clc_free_spir_binary(struct clc_binary *spir)
1419 {
1420    free(spir->data);
1421 }
1422 
1423 void
1424 clc_free_spirv_binary(struct clc_binary *spvbin)
1425 {
1426    free(spvbin->data);
1427 }
1428 
1429 void
1430 initialize_llvm_once(void)
1431 {
1432    LLVMInitializeAllTargets();
1433    LLVMInitializeAllTargetInfos();
1434    LLVMInitializeAllTargetMCs();
1435    LLVMInitializeAllAsmParsers();
1436    LLVMInitializeAllAsmPrinters();
1437 }
1438 
1439 std::once_flag initialize_llvm_once_flag;
1440 
1441 void
1442 clc_initialize_llvm(void)
1443 {
1444    std::call_once(initialize_llvm_once_flag,
1445                   []() { initialize_llvm_once(); });
1446 }
1447