• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2012-2016 Francisco Jerez
3 // Copyright 2012-2016 Advanced Micro Devices, Inc.
4 // Copyright 2014-2016 Jan Vesely
5 // Copyright 2014-2015 Serge Martin
6 // Copyright 2015 Zoltan Gilian
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a
9 // copy of this software and associated documentation files (the "Software"),
10 // to deal in the Software without restriction, including without limitation
11 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
12 // and/or sell copies of the Software, and to permit persons to whom the
13 // Software is furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24 // OTHER DEALINGS IN THE SOFTWARE.
25 
26 #include <cstdlib>
27 #include <filesystem>
28 #include <sstream>
29 #include <mutex>
30 
31 #include <llvm/ADT/ArrayRef.h>
32 #include <llvm/IR/DiagnosticPrinter.h>
33 #include <llvm/IR/DiagnosticInfo.h>
34 #include <llvm/IR/LegacyPassManager.h>
35 #include <llvm/IR/LLVMContext.h>
36 #include <llvm/IR/Type.h>
37 #include <llvm/MC/TargetRegistry.h>
38 #include <llvm/Target/TargetMachine.h>
39 #include <llvm/Support/raw_ostream.h>
40 #include <llvm/Bitcode/BitcodeWriter.h>
41 #include <llvm/Bitcode/BitcodeReader.h>
42 #include <llvm-c/Core.h>
43 #include <llvm-c/Target.h>
44 #include <LLVMSPIRVLib/LLVMSPIRVLib.h>
45 
46 #include <clang/Config/config.h>
47 #include <clang/Driver/Driver.h>
48 #include <clang/CodeGen/CodeGenAction.h>
49 #include <clang/Lex/PreprocessorOptions.h>
50 #include <clang/Frontend/CompilerInstance.h>
51 #include <clang/Frontend/TextDiagnosticBuffer.h>
52 #include <clang/Frontend/TextDiagnosticPrinter.h>
53 #include <clang/Basic/TargetInfo.h>
54 
55 #include <spirv-tools/libspirv.hpp>
56 #include <spirv-tools/linker.hpp>
57 #include <spirv-tools/optimizer.hpp>
58 
59 #include "util/macros.h"
60 #include "glsl_types.h"
61 
62 #include "spirv.h"
63 
64 #if DETECT_OS_UNIX
65 #include <dlfcn.h>
66 #endif
67 
68 #ifdef USE_STATIC_OPENCL_C_H
69 #include "opencl-c-base.h.h"
70 #include "opencl-c.h.h"
71 #endif
72 
73 #include "clc_helpers.h"
74 
75 namespace fs = std::filesystem;
76 
77 /* Use the highest version of SPIRV supported by SPIRV-Tools. */
78 constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;
79 
80 constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
81 
82 using ::llvm::Function;
83 using ::llvm::legacy::PassManager;
84 using ::llvm::LLVMContext;
85 using ::llvm::Module;
86 using ::llvm::raw_string_ostream;
87 using ::llvm::TargetRegistry;
88 using ::clang::driver::Driver;
89 
90 static void
91 clc_dump_llvm(const llvm::Module *mod, FILE *f);
92 
93 static void
llvm_log_handler(const::llvm::DiagnosticInfo & di,void * data)94 llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
95    const clc_logger *logger = static_cast<clc_logger *>(data);
96 
97    std::string log;
98    raw_string_ostream os { log };
99    ::llvm::DiagnosticPrinterRawOStream printer { os };
100    di.print(printer);
101 
102    clc_error(logger, "%s", log.c_str());
103 }
104 
105 class SPIRVKernelArg {
106 public:
SPIRVKernelArg(uint32_t id,uint32_t typeId)107    SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
108                                                   addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
109                                                   accessQualifier(0),
110                                                   typeQualifier(0) { }
~SPIRVKernelArg()111    ~SPIRVKernelArg() { }
112 
113    uint32_t id;
114    uint32_t typeId;
115    std::string name;
116    std::string typeName;
117    enum clc_kernel_arg_address_qualifier addrQualifier;
118    unsigned accessQualifier;
119    unsigned typeQualifier;
120 };
121 
122 class SPIRVKernelInfo {
123 public:
SPIRVKernelInfo(uint32_t fid,const char * nm)124    SPIRVKernelInfo(uint32_t fid, const char *nm)
125       : funcId(fid), name(nm), vecHint(0), localSize(), localSizeHint() { }
~SPIRVKernelInfo()126    ~SPIRVKernelInfo() { }
127 
128    uint32_t funcId;
129    std::string name;
130    std::vector<SPIRVKernelArg> args;
131    unsigned vecHint;
132    unsigned localSize[3];
133    unsigned localSizeHint[3];
134 };
135 
136 class SPIRVKernelParser {
137 public:
SPIRVKernelParser()138    SPIRVKernelParser() : curKernel(NULL)
139    {
140       ctx = spvContextCreate(spirv_target);
141    }
142 
~SPIRVKernelParser()143    ~SPIRVKernelParser()
144    {
145      spvContextDestroy(ctx);
146    }
147 
parseEntryPoint(const spv_parsed_instruction_t * ins)148    void parseEntryPoint(const spv_parsed_instruction_t *ins)
149    {
150       assert(ins->num_operands >= 3);
151 
152       const spv_parsed_operand_t *op = &ins->operands[1];
153 
154       assert(op->type == SPV_OPERAND_TYPE_ID);
155 
156       uint32_t funcId = ins->words[op->offset];
157 
158       for (auto &iter : kernels) {
159          if (funcId == iter.funcId)
160             return;
161       }
162 
163       op = &ins->operands[2];
164       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
165       const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
166 
167       kernels.push_back(SPIRVKernelInfo(funcId, name));
168    }
169 
parseFunction(const spv_parsed_instruction_t * ins)170    void parseFunction(const spv_parsed_instruction_t *ins)
171    {
172       assert(ins->num_operands == 4);
173 
174       const spv_parsed_operand_t *op = &ins->operands[1];
175 
176       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
177 
178       uint32_t funcId = ins->words[op->offset];
179 
180       for (auto &kernel : kernels) {
181          if (funcId == kernel.funcId && !kernel.args.size()) {
182             curKernel = &kernel;
183 	    return;
184          }
185       }
186    }
187 
parseFunctionParam(const spv_parsed_instruction_t * ins)188    void parseFunctionParam(const spv_parsed_instruction_t *ins)
189    {
190       const spv_parsed_operand_t *op;
191       uint32_t id, typeId;
192 
193       if (!curKernel)
194          return;
195 
196       assert(ins->num_operands == 2);
197       op = &ins->operands[0];
198       assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
199       typeId = ins->words[op->offset];
200       op = &ins->operands[1];
201       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
202       id = ins->words[op->offset];
203       curKernel->args.push_back(SPIRVKernelArg(id, typeId));
204    }
205 
parseName(const spv_parsed_instruction_t * ins)206    void parseName(const spv_parsed_instruction_t *ins)
207    {
208       const spv_parsed_operand_t *op;
209       const char *name;
210       uint32_t id;
211 
212       assert(ins->num_operands == 2);
213 
214       op = &ins->operands[0];
215       assert(op->type == SPV_OPERAND_TYPE_ID);
216       id = ins->words[op->offset];
217       op = &ins->operands[1];
218       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
219       name = reinterpret_cast<const char *>(ins->words + op->offset);
220 
221       for (auto &kernel : kernels) {
222          for (auto &arg : kernel.args) {
223             if (arg.id == id && arg.name.empty()) {
224               arg.name = name;
225               break;
226 	    }
227          }
228       }
229    }
230 
parseTypePointer(const spv_parsed_instruction_t * ins)231    void parseTypePointer(const spv_parsed_instruction_t *ins)
232    {
233       enum clc_kernel_arg_address_qualifier addrQualifier;
234       uint32_t typeId, storageClass;
235       const spv_parsed_operand_t *op;
236 
237       assert(ins->num_operands == 3);
238 
239       op = &ins->operands[0];
240       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
241       typeId = ins->words[op->offset];
242 
243       op = &ins->operands[1];
244       assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
245       storageClass = ins->words[op->offset];
246       switch (storageClass) {
247       case SpvStorageClassCrossWorkgroup:
248          addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
249          break;
250       case SpvStorageClassWorkgroup:
251          addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
252          break;
253       case SpvStorageClassUniformConstant:
254          addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
255          break;
256       default:
257          addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
258          break;
259       }
260 
261       for (auto &kernel : kernels) {
262 	 for (auto &arg : kernel.args) {
263             if (arg.typeId == typeId) {
264                arg.addrQualifier = addrQualifier;
265                if (addrQualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT)
266                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
267             }
268          }
269       }
270    }
271 
parseOpString(const spv_parsed_instruction_t * ins)272    void parseOpString(const spv_parsed_instruction_t *ins)
273    {
274       const spv_parsed_operand_t *op;
275       std::string str;
276 
277       assert(ins->num_operands == 2);
278 
279       op = &ins->operands[1];
280       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
281       str = reinterpret_cast<const char *>(ins->words + op->offset);
282 
283       size_t start = 0;
284       enum class string_type {
285          arg_type,
286          arg_type_qual,
287       } str_type;
288 
289       if (str.find("kernel_arg_type.") == 0) {
290          start = sizeof("kernel_arg_type.") - 1;
291          str_type = string_type::arg_type;
292       } else if (str.find("kernel_arg_type_qual.") == 0) {
293          start = sizeof("kernel_arg_type_qual.") - 1;
294          str_type = string_type::arg_type_qual;
295       } else {
296          return;
297       }
298 
299       for (auto &kernel : kernels) {
300          size_t pos;
301 
302 	 pos = str.find(kernel.name, start);
303          if (pos == std::string::npos ||
304              pos != start || str[start + kernel.name.size()] != '.')
305             continue;
306 
307 	 pos = start + kernel.name.size();
308          if (str[pos++] != '.')
309             continue;
310 
311          for (auto &arg : kernel.args) {
312             if (arg.name.empty())
313                break;
314 
315             size_t entryEnd = str.find(',', pos);
316 	    if (entryEnd == std::string::npos)
317                break;
318 
319             std::string entryVal = str.substr(pos, entryEnd - pos);
320             pos = entryEnd + 1;
321 
322             if (str_type == string_type::arg_type) {
323                arg.typeName = std::move(entryVal);
324             } else if (str_type == string_type::arg_type_qual) {
325                if (entryVal.find("const") != std::string::npos)
326                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
327             }
328          }
329       }
330    }
331 
applyDecoration(uint32_t id,const spv_parsed_instruction_t * ins)332    void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
333    {
334       auto iter = decorationGroups.find(id);
335       if (iter != decorationGroups.end()) {
336          for (uint32_t entry : iter->second)
337             applyDecoration(entry, ins);
338          return;
339       }
340 
341       const spv_parsed_operand_t *op;
342       uint32_t decoration;
343 
344       assert(ins->num_operands >= 2);
345 
346       op = &ins->operands[1];
347       assert(op->type == SPV_OPERAND_TYPE_DECORATION);
348       decoration = ins->words[op->offset];
349 
350       if (decoration == SpvDecorationSpecId) {
351          uint32_t spec_id = ins->words[ins->operands[2].offset];
352          for (auto &c : specConstants) {
353             if (c.second.id == spec_id) {
354                return;
355             }
356          }
357          specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
358          return;
359       }
360 
361       for (auto &kernel : kernels) {
362          for (auto &arg : kernel.args) {
363             if (arg.id == id) {
364                switch (decoration) {
365                case SpvDecorationVolatile:
366                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
367                   break;
368                case SpvDecorationConstant:
369                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
370                   break;
371                case SpvDecorationRestrict:
372                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
373                   break;
374                case SpvDecorationFuncParamAttr:
375                   op = &ins->operands[2];
376                   assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
377                   switch (ins->words[op->offset]) {
378                   case SpvFunctionParameterAttributeNoAlias:
379                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
380                      break;
381                   case SpvFunctionParameterAttributeNoWrite:
382                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
383                      break;
384                   }
385                   break;
386                }
387             }
388 
389          }
390       }
391    }
392 
parseOpDecorate(const spv_parsed_instruction_t * ins)393    void parseOpDecorate(const spv_parsed_instruction_t *ins)
394    {
395       const spv_parsed_operand_t *op;
396       uint32_t id;
397 
398       assert(ins->num_operands >= 2);
399 
400       op = &ins->operands[0];
401       assert(op->type == SPV_OPERAND_TYPE_ID);
402       id = ins->words[op->offset];
403 
404       applyDecoration(id, ins);
405    }
406 
parseOpGroupDecorate(const spv_parsed_instruction_t * ins)407    void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
408    {
409       assert(ins->num_operands >= 2);
410 
411       const spv_parsed_operand_t *op = &ins->operands[0];
412       assert(op->type == SPV_OPERAND_TYPE_ID);
413       uint32_t groupId = ins->words[op->offset];
414 
415       auto lowerBound = decorationGroups.lower_bound(groupId);
416       if (lowerBound != decorationGroups.end() &&
417           lowerBound->first == groupId)
418          // Group already filled out
419          return;
420 
421       auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
422       auto& vec = iter->second;
423       vec.reserve(ins->num_operands - 1);
424       for (uint32_t i = 1; i < ins->num_operands; ++i) {
425          op = &ins->operands[i];
426          assert(op->type == SPV_OPERAND_TYPE_ID);
427          vec.push_back(ins->words[op->offset]);
428       }
429    }
430 
parseOpTypeImage(const spv_parsed_instruction_t * ins)431    void parseOpTypeImage(const spv_parsed_instruction_t *ins)
432    {
433       const spv_parsed_operand_t *op;
434       uint32_t typeId;
435       unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
436 
437       op = &ins->operands[0];
438       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
439       typeId = ins->words[op->offset];
440 
441       if (ins->num_operands >= 9) {
442          op = &ins->operands[8];
443          assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
444          switch (ins->words[op->offset]) {
445          case SpvAccessQualifierReadOnly:
446             accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
447             break;
448          case SpvAccessQualifierWriteOnly:
449             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
450             break;
451          case SpvAccessQualifierReadWrite:
452             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
453                CLC_KERNEL_ARG_ACCESS_READ;
454             break;
455          }
456       }
457 
458       for (auto &kernel : kernels) {
459 	 for (auto &arg : kernel.args) {
460             if (arg.typeId == typeId) {
461                arg.accessQualifier = accessQualifier;
462                arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
463             }
464          }
465       }
466    }
467 
parseExecutionMode(const spv_parsed_instruction_t * ins)468    void parseExecutionMode(const spv_parsed_instruction_t *ins)
469    {
470       uint32_t executionMode = ins->words[ins->operands[1].offset];
471       uint32_t funcId = ins->words[ins->operands[0].offset];
472 
473       for (auto& kernel : kernels) {
474          if (kernel.funcId == funcId) {
475             switch (executionMode) {
476             case SpvExecutionModeVecTypeHint:
477                kernel.vecHint = ins->words[ins->operands[2].offset];
478                break;
479             case SpvExecutionModeLocalSize:
480                kernel.localSize[0] = ins->words[ins->operands[2].offset];
481                kernel.localSize[1] = ins->words[ins->operands[3].offset];
482                kernel.localSize[2] = ins->words[ins->operands[4].offset];
483             case SpvExecutionModeLocalSizeHint:
484                kernel.localSizeHint[0] = ins->words[ins->operands[2].offset];
485                kernel.localSizeHint[1] = ins->words[ins->operands[3].offset];
486                kernel.localSizeHint[2] = ins->words[ins->operands[4].offset];
487             default:
488                return;
489             }
490          }
491       }
492    }
493 
parseLiteralType(const spv_parsed_instruction_t * ins)494    void parseLiteralType(const spv_parsed_instruction_t *ins)
495    {
496       uint32_t typeId = ins->words[ins->operands[0].offset];
497       auto& literalType = literalTypes[typeId];
498       switch (ins->opcode) {
499       case SpvOpTypeBool:
500          literalType = CLC_SPEC_CONSTANT_BOOL;
501          break;
502       case SpvOpTypeFloat: {
503          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
504          switch (sizeInBits) {
505          case 32:
506             literalType = CLC_SPEC_CONSTANT_FLOAT;
507             break;
508          case 64:
509             literalType = CLC_SPEC_CONSTANT_DOUBLE;
510             break;
511          case 16:
512             /* Can't be used for a spec constant */
513             break;
514          default:
515             unreachable("Unexpected float bit size");
516          }
517          break;
518       }
519       case SpvOpTypeInt: {
520          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
521          bool isSigned = ins->words[ins->operands[2].offset];
522          if (isSigned) {
523             switch (sizeInBits) {
524             case 8:
525                literalType = CLC_SPEC_CONSTANT_INT8;
526                break;
527             case 16:
528                literalType = CLC_SPEC_CONSTANT_INT16;
529                break;
530             case 32:
531                literalType = CLC_SPEC_CONSTANT_INT32;
532                break;
533             case 64:
534                literalType = CLC_SPEC_CONSTANT_INT64;
535                break;
536             default:
537                unreachable("Unexpected int bit size");
538             }
539          } else {
540             switch (sizeInBits) {
541             case 8:
542                literalType = CLC_SPEC_CONSTANT_UINT8;
543                break;
544             case 16:
545                literalType = CLC_SPEC_CONSTANT_UINT16;
546                break;
547             case 32:
548                literalType = CLC_SPEC_CONSTANT_UINT32;
549                break;
550             case 64:
551                literalType = CLC_SPEC_CONSTANT_UINT64;
552                break;
553             default:
554                unreachable("Unexpected uint bit size");
555             }
556          }
557          break;
558       }
559       default:
560          unreachable("Unexpected type opcode");
561       }
562    }
563 
parseSpecConstant(const spv_parsed_instruction_t * ins)564    void parseSpecConstant(const spv_parsed_instruction_t *ins)
565    {
566       uint32_t id = ins->result_id;
567       for (auto& c : specConstants) {
568          if (c.first == id) {
569             auto& data = c.second;
570             switch (ins->opcode) {
571             case SpvOpSpecConstant: {
572                uint32_t typeId = ins->words[ins->operands[0].offset];
573 
574                // This better be an integer or float type
575                auto typeIter = literalTypes.find(typeId);
576                assert(typeIter != literalTypes.end());
577 
578                data.type = typeIter->second;
579                break;
580             }
581             case SpvOpSpecConstantFalse:
582             case SpvOpSpecConstantTrue:
583                data.type = CLC_SPEC_CONSTANT_BOOL;
584                break;
585             default:
586                unreachable("Composites and Ops are not directly specializable.");
587             }
588          }
589       }
590    }
591 
592    static spv_result_t
parseInstruction(void * data,const spv_parsed_instruction_t * ins)593    parseInstruction(void *data, const spv_parsed_instruction_t *ins)
594    {
595       SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
596 
597       switch (ins->opcode) {
598       case SpvOpName:
599          parser->parseName(ins);
600          break;
601       case SpvOpEntryPoint:
602          parser->parseEntryPoint(ins);
603          break;
604       case SpvOpFunction:
605          parser->parseFunction(ins);
606          break;
607       case SpvOpFunctionParameter:
608          parser->parseFunctionParam(ins);
609          break;
610       case SpvOpFunctionEnd:
611       case SpvOpLabel:
612          parser->curKernel = NULL;
613          break;
614       case SpvOpTypePointer:
615          parser->parseTypePointer(ins);
616          break;
617       case SpvOpTypeImage:
618          parser->parseOpTypeImage(ins);
619          break;
620       case SpvOpString:
621          parser->parseOpString(ins);
622          break;
623       case SpvOpDecorate:
624          parser->parseOpDecorate(ins);
625          break;
626       case SpvOpGroupDecorate:
627          parser->parseOpGroupDecorate(ins);
628          break;
629       case SpvOpExecutionMode:
630          parser->parseExecutionMode(ins);
631          break;
632       case SpvOpTypeBool:
633       case SpvOpTypeInt:
634       case SpvOpTypeFloat:
635          parser->parseLiteralType(ins);
636          break;
637       case SpvOpSpecConstant:
638       case SpvOpSpecConstantFalse:
639       case SpvOpSpecConstantTrue:
640          parser->parseSpecConstant(ins);
641          break;
642       default:
643          break;
644       }
645 
646       return SPV_SUCCESS;
647    }
648 
parseBinary(const struct clc_binary & spvbin,const struct clc_logger * logger)649    bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
650    {
651       /* 3 passes should be enough to retrieve all kernel information:
652        * 1st pass: all entry point name and number of args
653        * 2nd pass: argument names and type names
654        * 3rd pass: pointer type names
655        */
656       for (unsigned pass = 0; pass < 3; pass++) {
657          spv_diagnostic diagnostic = NULL;
658          auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
659                                       static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
660                                       NULL, parseInstruction, &diagnostic);
661 
662          if (result != SPV_SUCCESS) {
663             if (diagnostic && logger)
664                logger->error(logger->priv, diagnostic->error);
665             return false;
666          }
667       }
668 
669       return true;
670    }
671 
672    std::vector<SPIRVKernelInfo> kernels;
673    std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
674    std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
675    std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
676    SPIRVKernelInfo *curKernel;
677    spv_context ctx;
678 };
679 
680 bool
clc_spirv_get_kernels_info(const struct clc_binary * spvbin,const struct clc_kernel_info ** out_kernels,unsigned * num_kernels,const struct clc_parsed_spec_constant ** out_spec_constants,unsigned * num_spec_constants,const struct clc_logger * logger)681 clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
682                            const struct clc_kernel_info **out_kernels,
683                            unsigned *num_kernels,
684                            const struct clc_parsed_spec_constant **out_spec_constants,
685                            unsigned *num_spec_constants,
686                            const struct clc_logger *logger)
687 {
688    struct clc_kernel_info *kernels = NULL;
689    struct clc_parsed_spec_constant *spec_constants = NULL;
690 
691    SPIRVKernelParser parser;
692 
693    if (!parser.parseBinary(*spvbin, logger))
694       return false;
695 
696    *num_kernels = parser.kernels.size();
697    *num_spec_constants = parser.specConstants.size();
698    if (*num_kernels) {
699       kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
700                                                                   sizeof(*kernels)));
701       assert(kernels);
702       for (unsigned i = 0; i < parser.kernels.size(); i++) {
703          kernels[i].name = strdup(parser.kernels[i].name.c_str());
704          kernels[i].num_args = parser.kernels[i].args.size();
705          kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
706          kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
707          memcpy(kernels[i].local_size, parser.kernels[i].localSize, sizeof(kernels[i].local_size));
708          memcpy(kernels[i].local_size_hint, parser.kernels[i].localSizeHint, sizeof(kernels[i].local_size_hint));
709          if (!kernels[i].num_args)
710             continue;
711 
712          struct clc_kernel_arg *args;
713 
714          args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
715                                                                  sizeof(*kernels->args)));
716          kernels[i].args = args;
717          assert(args);
718          for (unsigned j = 0; j < kernels[i].num_args; j++) {
719             if (!parser.kernels[i].args[j].name.empty())
720                args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
721             args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
722             args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
723             args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
724             args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
725          }
726       }
727    }
728 
729    if (*num_spec_constants) {
730       spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
731                                                                                   sizeof(*spec_constants)));
732       assert(spec_constants);
733 
734       for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
735          spec_constants[i] = parser.specConstants[i].second;
736       }
737    }
738 
739    *out_kernels = kernels;
740    *out_spec_constants = spec_constants;
741 
742    return true;
743 }
744 
745 void
clc_free_kernels_info(const struct clc_kernel_info * kernels,unsigned num_kernels)746 clc_free_kernels_info(const struct clc_kernel_info *kernels,
747                       unsigned num_kernels)
748 {
749    if (!kernels)
750       return;
751 
752    for (unsigned i = 0; i < num_kernels; i++) {
753       if (kernels[i].args) {
754          for (unsigned j = 0; j < kernels[i].num_args; j++) {
755             free((void *)kernels[i].args[j].name);
756             free((void *)kernels[i].args[j].type_name);
757          }
758          free((void *)kernels[i].args);
759       }
760       free((void *)kernels[i].name);
761    }
762 
763    free((void *)kernels);
764 }
765 
766 static std::unique_ptr<::llvm::Module>
clc_compile_to_llvm_module(LLVMContext & llvm_ctx,const struct clc_compile_args * args,const struct clc_logger * logger)767 clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
768                            const struct clc_compile_args *args,
769                            const struct clc_logger *logger)
770 {
771    static_assert(std::has_unique_object_representations<clc_optional_features>(),
772                  "no padding allowed inside clc_optional_features");
773 
774    std::string diag_log_str;
775    raw_string_ostream diag_log_stream { diag_log_str };
776 
777    std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
778 
779    clang::DiagnosticsEngine diag {
780       new clang::DiagnosticIDs,
781       new clang::DiagnosticOptions,
782       new clang::TextDiagnosticPrinter(diag_log_stream,
783                                        &c->getDiagnosticOpts())
784    };
785 
786 #if LLVM_VERSION_MAJOR >= 17
787    const char *triple = args->address_bits == 32 ? "spirv-unknown-unknown" : "spirv64-unknown-unknown";
788 #else
789    const char *triple = args->address_bits == 32 ? "spir-unknown-unknown" : "spir64-unknown-unknown";
790 #endif
791 
792    std::vector<const char *> clang_opts = {
793       args->source.name,
794       "-triple", triple,
795       // By default, clang prefers to use modules to pull in the default headers,
796       // which doesn't work with our technique of embedding the headers in our binary
797       "-fdeclare-opencl-builtins",
798 #if LLVM_VERSION_MAJOR < 17
799       "-no-opaque-pointers",
800 #endif
801       // Add a default CL compiler version. Clang will pick the last one specified
802       // on the command line, so the app can override this one.
803       "-cl-std=cl1.2",
804       // The LLVM-SPIRV-Translator doesn't support memset with variable size
805       "-fno-builtin-memset",
806       // LLVM's optimizations can produce code that the translator can't translate
807       "-O0",
808       // Ensure inline functions are actually emitted
809       "-fgnu89-inline",
810    };
811 
812    // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
813    // being provided by the caller here.
814    clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);
815 
816    if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
817                                                   clang_opts,
818                                                   diag)) {
819       clc_error(logger, "Couldn't create Clang invocation.\n");
820       return {};
821    }
822 
823    if (diag.hasErrorOccurred()) {
824       clc_error(logger, "%sErrors occurred during Clang invocation.\n",
825                 diag_log_str.c_str());
826       return {};
827    }
828 
829    // This is a workaround for a Clang bug which causes the number
830    // of warnings and errors to be printed to stderr.
831    // http://www.llvm.org/bugs/show_bug.cgi?id=19735
832    c->getDiagnosticOpts().ShowCarets = false;
833 
834    c->createDiagnostics(new clang::TextDiagnosticPrinter(
835                            diag_log_stream,
836                            &c->getDiagnosticOpts()));
837 
838    c->setTarget(clang::TargetInfo::CreateTargetInfo(
839                    c->getDiagnostics(), c->getInvocation().TargetOpts));
840 
841    c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;
842 
843 #ifdef USE_STATIC_OPENCL_C_H
844    c->getHeaderSearchOpts().UseBuiltinIncludes = false;
845    c->getHeaderSearchOpts().UseStandardSystemIncludes = false;
846 
847    // Add opencl-c generic search path
848    {
849       ::llvm::SmallString<128> system_header_path;
850       ::llvm::sys::path::system_temp_directory(true, system_header_path);
851       ::llvm::sys::path::append(system_header_path, "openclon12");
852       c->getHeaderSearchOpts().AddPath(system_header_path.str(),
853                                        clang::frontend::Angled,
854                                        false, false);
855 
856       ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
857       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
858          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
859       // this line is actually important to make it include `opencl-c.h`
860       ::llvm::sys::path::remove_filename(system_header_path);
861       ::llvm::sys::path::append(system_header_path, "opencl-c.h");
862       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
863          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());
864    }
865 #else
866 
867    Dl_info info;
868    if (dladdr((void *)clang::CompilerInvocation::CreateFromArgs, &info) == 0) {
869       clc_error(logger, "Couldn't find libclang path.\n");
870       return {};
871    }
872 
873    char *clang_path = realpath(info.dli_fname, NULL);
874    if (clang_path == nullptr) {
875       clc_error(logger, "Couldn't find libclang path.\n");
876       return {};
877    }
878 
879    // GetResourcePath is a way to retrive the actual libclang resource dir based on a given binary
880    // or library.
881    auto clang_res_path =
882       fs::path(Driver::GetResourcesPath(std::string(clang_path), CLANG_RESOURCE_DIR)) / "include";
883    free(clang_path);
884 
885    c->getHeaderSearchOpts().UseBuiltinIncludes = true;
886    c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
887    c->getHeaderSearchOpts().ResourceDir = clang_res_path.string();
888 
889    // Add opencl-c generic search path
890    c->getHeaderSearchOpts().AddPath(clang_res_path.string(),
891                                     clang::frontend::Angled,
892                                     false, false);
893 #endif
894 
895    // Enable/Disable optional OpenCL C features. Some can be toggled via `OpenCLExtensionsAsWritten`
896    // others we have to (un)define via macros ourselves.
897 
898    // Undefine clang added SPIR(V) defines so we don't magically enable extensions
899    c->getPreprocessorOpts().addMacroUndef("__SPIR__");
900    c->getPreprocessorOpts().addMacroUndef("__SPIRV__");
901 
902    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("-all");
903    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_byte_addressable_store");
904    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_base_atomics");
905    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_extended_atomics");
906    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_base_atomics");
907    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_extended_atomics");
908    c->getPreprocessorOpts().addMacroDef("cl_khr_expect_assume=1");
909 
910    bool needs_opencl_c_h = false;
911    if (args->features.fp16) {
912       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp16");
913    }
914    if (args->features.fp64) {
915       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp64");
916       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_fp64");
917    }
918    if (args->features.int64) {
919       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cles_khr_int64");
920       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_int64");
921    } else {
922       // clang defines this unconditionally, we need to fix that.
923       c->getPreprocessorOpts().addMacroUndef("__opencl_c_int64");
924    }
925    if (args->features.images) {
926       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_images");
927    } else {
928       // clang defines this unconditionally, we need to fix that.
929       c->getPreprocessorOpts().addMacroUndef("__IMAGE_SUPPORT__");
930    }
931    if (args->features.images_read_write) {
932       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_read_write_images");
933    }
934    if (args->features.images_write_3d) {
935       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_3d_image_writes");
936       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_3d_image_writes");
937    }
938    if (args->features.intel_subgroups) {
939       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_intel_subgroups");
940       needs_opencl_c_h = true;
941    }
942    if (args->features.subgroups) {
943       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_subgroups");
944       if (args->features.subgroups_shuffle) {
945          c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle=1");
946       }
947       if (args->features.subgroups_shuffle_relative) {
948          c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle_relative=1");
949       }
950    }
951    if (args->features.subgroups_ifp) {
952       assert(args->features.subgroups);
953       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_subgroups");
954    }
955    if (args->features.integer_dot_product) {
956       c->getPreprocessorOpts().addMacroDef("cl_khr_integer_dot_product=1");
957       c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit_packed=1");
958       c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit=1");
959    }
960 
961    // Add opencl include
962    c->getPreprocessorOpts().Includes.push_back("opencl-c-base.h");
963    if (needs_opencl_c_h) {
964       c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
965    }
966 
967    if (args->num_headers) {
968       ::llvm::SmallString<128> tmp_header_path;
969       ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
970       ::llvm::sys::path::append(tmp_header_path, "openclon12");
971 
972       c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
973                                        clang::frontend::Quoted,
974                                        false, false);
975 
976       for (size_t i = 0; i < args->num_headers; i++) {
977          auto path_copy = tmp_header_path;
978          ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
979          c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
980             ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
981       }
982    }
983 
984    c->getPreprocessorOpts().addRemappedFile(
985            args->source.name,
986            ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
987 
988    // Compile the code
989    clang::EmitLLVMOnlyAction act(&llvm_ctx);
990    if (!c->ExecuteAction(act)) {
991       clc_error(logger, "%sError executing LLVM compilation action.\n",
992                 diag_log_str.c_str());
993       return {};
994    }
995 
996    auto mod = act.takeModule();
997 
998    if (clc_debug_flags() & CLC_DEBUG_DUMP_LLVM)
999       clc_dump_llvm(mod.get(), stdout);
1000 
1001    return mod;
1002 }
1003 
1004 static SPIRV::VersionNumber
spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)1005 spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
1006 {
1007    switch (version) {
1008    case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
1009    case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
1010    case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
1011    case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
1012    case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
1013    case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
1014    default:      return invalid_spirv_trans_version;
1015    }
1016 }
1017 
1018 static int
llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,LLVMContext & context,const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)1019 llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
1020                   LLVMContext &context,
1021                   const struct clc_compile_args *args,
1022                   const struct clc_logger *logger,
1023                   struct clc_binary *out_spirv)
1024 {
1025    std::string log;
1026 
1027    SPIRV::VersionNumber version =
1028       spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
1029    if (version == invalid_spirv_trans_version) {
1030       clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
1031       return -1;
1032    }
1033 
1034    const char *const *extensions = NULL;
1035    if (args)
1036       extensions = args->allowed_spirv_extensions;
1037    if (!extensions) {
1038       /* The SPIR-V parser doesn't handle all extensions */
1039       static const char *default_extensions[] = {
1040          "SPV_EXT_shader_atomic_float_add",
1041          "SPV_EXT_shader_atomic_float_min_max",
1042          "SPV_KHR_float_controls",
1043          NULL,
1044       };
1045       extensions = default_extensions;
1046    }
1047 
1048    SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
1049    for (int i = 0; extensions[i]; i++) {
1050 #define EXT(X) \
1051       if (strcmp(#X, extensions[i]) == 0) \
1052          ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
1053 #include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
1054 #undef EXT
1055    }
1056    SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
1057 
1058    /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
1059    spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
1060 
1061 #if LLVM_VERSION_MAJOR >= 17
1062    if (args->use_llvm_spirv_target) {
1063       const char *triple = args->address_bits == 32 ? "spirv-unknown-unknown" : "spirv64-unknown-unknown";
1064       std::string error_msg("");
1065       auto target = TargetRegistry::lookupTarget(triple, error_msg);
1066       if (target) {
1067          auto TM = target->createTargetMachine(
1068             triple, "", "", {}, std::nullopt, std::nullopt,
1069 #if LLVM_VERSION_MAJOR >= 18
1070             ::llvm::CodeGenOptLevel::None
1071 #else
1072             ::llvm::CodeGenOpt::None
1073 #endif
1074          );
1075 
1076          auto PM = PassManager();
1077          ::llvm::SmallVector<char> buf;
1078          auto OS = ::llvm::raw_svector_ostream(buf);
1079          TM->addPassesToEmitFile(
1080             PM, OS, nullptr,
1081 #if LLVM_VERSION_MAJOR >= 18
1082             ::llvm::CodeGenFileType::ObjectFile
1083 #else
1084             ::llvm::CGFT_ObjectFile
1085 #endif
1086          );
1087 
1088          PM.run(*mod);
1089 
1090          out_spirv->size = buf.size_in_bytes();
1091          out_spirv->data = malloc(out_spirv->size);
1092          memcpy(out_spirv->data, buf.data(), out_spirv->size);
1093          return 0;
1094       } else {
1095          clc_error(logger, "LLVM SPIR-V target not found.\n");
1096          return -1;
1097       }
1098    }
1099 #endif
1100 
1101    std::ostringstream spv_stream;
1102    if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
1103       clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
1104                 log.c_str());
1105       return -1;
1106    }
1107 
1108    const std::string spv_out = spv_stream.str();
1109    out_spirv->size = spv_out.size();
1110    out_spirv->data = malloc(out_spirv->size);
1111    memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
1112 
1113    return 0;
1114 }
1115 
1116 int
clc_c_to_spir(const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spir)1117 clc_c_to_spir(const struct clc_compile_args *args,
1118               const struct clc_logger *logger,
1119               struct clc_binary *out_spir)
1120 {
1121    clc_initialize_llvm();
1122 
1123    LLVMContext llvm_ctx;
1124    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1125                                          const_cast<clc_logger *>(logger));
1126 
1127    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1128    if (!mod)
1129       return -1;
1130 
1131    ::llvm::SmallVector<char, 0> buffer;
1132    ::llvm::BitcodeWriter writer(buffer);
1133    writer.writeModule(*mod);
1134 
1135    out_spir->size = buffer.size_in_bytes();
1136    out_spir->data = malloc(out_spir->size);
1137    memcpy(out_spir->data, buffer.data(), out_spir->size);
1138 
1139    return 0;
1140 }
1141 
1142 int
clc_c_to_spirv(const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)1143 clc_c_to_spirv(const struct clc_compile_args *args,
1144                const struct clc_logger *logger,
1145                struct clc_binary *out_spirv)
1146 {
1147    clc_initialize_llvm();
1148 
1149    LLVMContext llvm_ctx;
1150    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1151                                          const_cast<clc_logger *>(logger));
1152 
1153    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1154    if (!mod)
1155       return -1;
1156    return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
1157 }
1158 
1159 int
clc_spir_to_spirv(const struct clc_binary * in_spir,const struct clc_logger * logger,struct clc_binary * out_spirv)1160 clc_spir_to_spirv(const struct clc_binary *in_spir,
1161                   const struct clc_logger *logger,
1162                   struct clc_binary *out_spirv)
1163 {
1164    clc_initialize_llvm();
1165 
1166    LLVMContext llvm_ctx;
1167    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1168                                          const_cast<clc_logger *>(logger));
1169 
1170    ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
1171    auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
1172    if (!mod)
1173       return -1;
1174 
1175    return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
1176 }
1177 
1178 class SPIRVMessageConsumer {
1179 public:
SPIRVMessageConsumer(const struct clc_logger * logger)1180    SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
1181 
operator ()(spv_message_level_t level,const char * src,const spv_position_t & pos,const char * msg)1182    void operator()(spv_message_level_t level, const char *src,
1183                    const spv_position_t &pos, const char *msg)
1184    {
1185       if (level == SPV_MSG_INFO || level == SPV_MSG_DEBUG)
1186          return;
1187 
1188       std::ostringstream message;
1189       message << "(file=" << (src ? src : "input")
1190               << ",line=" << pos.line
1191               << ",column=" << pos.column
1192               << ",index=" << pos.index
1193               << "): " << msg << "\n";
1194 
1195       if (level == SPV_MSG_WARNING)
1196          clc_warning(logger, "%s", message.str().c_str());
1197       else
1198          clc_error(logger, "%s", message.str().c_str());
1199    }
1200 
1201 private:
1202    const struct clc_logger *logger;
1203 };
1204 
1205 int
clc_link_spirv_binaries(const struct clc_linker_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)1206 clc_link_spirv_binaries(const struct clc_linker_args *args,
1207                         const struct clc_logger *logger,
1208                         struct clc_binary *out_spirv)
1209 {
1210    std::vector<std::vector<uint32_t>> binaries;
1211 
1212    for (unsigned i = 0; i < args->num_in_objs; i++) {
1213       const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1214       std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1215       binaries.push_back(bin);
1216    }
1217 
1218    SPIRVMessageConsumer msgconsumer(logger);
1219    spvtools::Context context(spirv_target);
1220    context.SetMessageConsumer(msgconsumer);
1221    spvtools::LinkerOptions options;
1222    options.SetAllowPartialLinkage(args->create_library);
1223    options.SetCreateLibrary(args->create_library);
1224    std::vector<uint32_t> linkingResult;
1225    spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1226    if (status != SPV_SUCCESS) {
1227       return -1;
1228    }
1229 
1230    out_spirv->size = linkingResult.size() * 4;
1231    out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1232    memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1233 
1234    return 0;
1235 }
1236 
1237 bool
clc_validate_spirv(const struct clc_binary * spirv,const struct clc_logger * logger,const struct clc_validator_options * options)1238 clc_validate_spirv(const struct clc_binary *spirv,
1239                    const struct clc_logger *logger,
1240                    const struct clc_validator_options *options)
1241 {
1242    SPIRVMessageConsumer msgconsumer(logger);
1243    spvtools::SpirvTools tools(spirv_target);
1244    tools.SetMessageConsumer(msgconsumer);
1245    spvtools::ValidatorOptions spirv_options;
1246    const uint32_t *data = static_cast<const uint32_t *>(spirv->data);
1247 
1248    if (options) {
1249       spirv_options.SetUniversalLimit(
1250          spv_validator_limit_max_function_args,
1251          options->limit_max_function_arg);
1252    }
1253 
1254    return tools.Validate(data, spirv->size / 4, spirv_options);
1255 }
1256 
1257 int
clc_spirv_specialize(const struct clc_binary * in_spirv,const struct clc_parsed_spirv * parsed_data,const struct clc_spirv_specialization_consts * consts,struct clc_binary * out_spirv)1258 clc_spirv_specialize(const struct clc_binary *in_spirv,
1259                      const struct clc_parsed_spirv *parsed_data,
1260                      const struct clc_spirv_specialization_consts *consts,
1261                      struct clc_binary *out_spirv)
1262 {
1263    std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1264    for (unsigned i = 0; i < consts->num_specializations; ++i) {
1265       unsigned id = consts->specializations[i].id;
1266       auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1267          parsed_data->spec_constants + parsed_data->num_spec_constants,
1268          [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1269       assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1270 
1271       std::vector<uint32_t> words;
1272       switch (parsed_spec_const->type) {
1273       case CLC_SPEC_CONSTANT_BOOL:
1274          words.push_back(consts->specializations[i].value.b);
1275          break;
1276       case CLC_SPEC_CONSTANT_INT32:
1277       case CLC_SPEC_CONSTANT_UINT32:
1278       case CLC_SPEC_CONSTANT_FLOAT:
1279          words.push_back(consts->specializations[i].value.u32);
1280          break;
1281       case CLC_SPEC_CONSTANT_INT16:
1282          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1283          break;
1284       case CLC_SPEC_CONSTANT_INT8:
1285          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1286          break;
1287       case CLC_SPEC_CONSTANT_UINT16:
1288          words.push_back((uint32_t)consts->specializations[i].value.u16);
1289          break;
1290       case CLC_SPEC_CONSTANT_UINT8:
1291          words.push_back((uint32_t)consts->specializations[i].value.u8);
1292          break;
1293       case CLC_SPEC_CONSTANT_DOUBLE:
1294       case CLC_SPEC_CONSTANT_INT64:
1295       case CLC_SPEC_CONSTANT_UINT64:
1296          words.resize(2);
1297          memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1298          break;
1299       case CLC_SPEC_CONSTANT_UNKNOWN:
1300          assert(0);
1301          break;
1302       }
1303 
1304       ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1305       assert(ret.second);
1306    }
1307 
1308    spvtools::Optimizer opt(spirv_target);
1309    opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1310 
1311    std::vector<uint32_t> result;
1312    if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1313       return false;
1314 
1315    out_spirv->size = result.size() * 4;
1316    out_spirv->data = malloc(out_spirv->size);
1317    memcpy(out_spirv->data, result.data(), out_spirv->size);
1318    return true;
1319 }
1320 
1321 static void
clc_dump_llvm(const llvm::Module * mod,FILE * f)1322 clc_dump_llvm(const llvm::Module *mod, FILE *f)
1323 {
1324    std::string out;
1325    raw_string_ostream os(out);
1326 
1327    mod->print(os, nullptr);
1328    os.flush();
1329 
1330    fwrite(out.c_str(), out.size(), 1, f);
1331 }
1332 
1333 void
clc_dump_spirv(const struct clc_binary * spvbin,FILE * f)1334 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1335 {
1336    spvtools::SpirvTools tools(spirv_target);
1337    const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1338    std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1339    std::string out;
1340    tools.Disassemble(bin, &out,
1341                      SPV_BINARY_TO_TEXT_OPTION_INDENT |
1342                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1343    fwrite(out.c_str(), out.size(), 1, f);
1344 }
1345 
1346 void
clc_free_spir_binary(struct clc_binary * spir)1347 clc_free_spir_binary(struct clc_binary *spir)
1348 {
1349    free(spir->data);
1350 }
1351 
1352 void
clc_free_spirv_binary(struct clc_binary * spvbin)1353 clc_free_spirv_binary(struct clc_binary *spvbin)
1354 {
1355    free(spvbin->data);
1356 }
1357 
1358 void
initialize_llvm_once(void)1359 initialize_llvm_once(void)
1360 {
1361    LLVMInitializeAllTargets();
1362    LLVMInitializeAllTargetInfos();
1363    LLVMInitializeAllTargetMCs();
1364    LLVMInitializeAllAsmParsers();
1365    LLVMInitializeAllAsmPrinters();
1366 }
1367 
1368 std::once_flag initialize_llvm_once_flag;
1369 
1370 void
clc_initialize_llvm(void)1371 clc_initialize_llvm(void)
1372 {
1373    std::call_once(initialize_llvm_once_flag,
1374                   []() { initialize_llvm_once(); });
1375 }
1376