• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2012-2016 Francisco Jerez
3 // Copyright 2012-2016 Advanced Micro Devices, Inc.
4 // Copyright 2014-2016 Jan Vesely
5 // Copyright 2014-2015 Serge Martin
6 // Copyright 2015 Zoltan Gilian
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a
9 // copy of this software and associated documentation files (the "Software"),
10 // to deal in the Software without restriction, including without limitation
11 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
12 // and/or sell copies of the Software, and to permit persons to whom the
13 // Software is furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24 // OTHER DEALINGS IN THE SOFTWARE.
25 
26 #include <sstream>
27 #include <mutex>
28 
29 #include <llvm/ADT/ArrayRef.h>
30 #include <llvm/IR/DiagnosticPrinter.h>
31 #include <llvm/IR/DiagnosticInfo.h>
32 #include <llvm/IR/LLVMContext.h>
33 #include <llvm/IR/Type.h>
34 #include <llvm/Support/raw_ostream.h>
35 #include <llvm/Bitcode/BitcodeWriter.h>
36 #include <llvm/Bitcode/BitcodeReader.h>
37 #include <llvm-c/Core.h>
38 #include <llvm-c/Target.h>
39 #include <LLVMSPIRVLib/LLVMSPIRVLib.h>
40 
41 #include <clang/CodeGen/CodeGenAction.h>
42 #include <clang/Lex/PreprocessorOptions.h>
43 #include <clang/Frontend/CompilerInstance.h>
44 #include <clang/Frontend/TextDiagnosticBuffer.h>
45 #include <clang/Frontend/TextDiagnosticPrinter.h>
46 #include <clang/Basic/TargetInfo.h>
47 
48 #include <spirv-tools/libspirv.hpp>
49 #include <spirv-tools/linker.hpp>
50 #include <spirv-tools/optimizer.hpp>
51 
52 #include "util/macros.h"
53 #include "glsl_types.h"
54 
55 #include "spirv.h"
56 
57 #ifdef USE_STATIC_OPENCL_C_H
58 #if LLVM_VERSION_MAJOR < 15
59 #include "opencl-c.h.h"
60 #endif
61 #include "opencl-c-base.h.h"
62 #endif
63 
64 #include "clc_helpers.h"
65 
66 /* Use the highest version of SPIRV supported by SPIRV-Tools. */
67 constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;
68 
69 constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
70 
71 using ::llvm::Function;
72 using ::llvm::LLVMContext;
73 using ::llvm::Module;
74 using ::llvm::raw_string_ostream;
75 
76 static void
llvm_log_handler(const::llvm::DiagnosticInfo & di,void * data)77 llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
78    const clc_logger *logger = static_cast<clc_logger *>(data);
79 
80    std::string log;
81    raw_string_ostream os { log };
82    ::llvm::DiagnosticPrinterRawOStream printer { os };
83    di.print(printer);
84 
85    clc_error(logger, "%s", log.c_str());
86 }
87 
88 class SPIRVKernelArg {
89 public:
SPIRVKernelArg(uint32_t id,uint32_t typeId)90    SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
91                                                   addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
92                                                   accessQualifier(0),
93                                                   typeQualifier(0) { }
~SPIRVKernelArg()94    ~SPIRVKernelArg() { }
95 
96    uint32_t id;
97    uint32_t typeId;
98    std::string name;
99    std::string typeName;
100    enum clc_kernel_arg_address_qualifier addrQualifier;
101    unsigned accessQualifier;
102    unsigned typeQualifier;
103 };
104 
105 class SPIRVKernelInfo {
106 public:
SPIRVKernelInfo(uint32_t fid,const char * nm)107    SPIRVKernelInfo(uint32_t fid, const char *nm)
108       : funcId(fid), name(nm), vecHint(0), localSize(), localSizeHint() { }
~SPIRVKernelInfo()109    ~SPIRVKernelInfo() { }
110 
111    uint32_t funcId;
112    std::string name;
113    std::vector<SPIRVKernelArg> args;
114    unsigned vecHint;
115    unsigned localSize[3];
116    unsigned localSizeHint[3];
117 };
118 
119 class SPIRVKernelParser {
120 public:
SPIRVKernelParser()121    SPIRVKernelParser() : curKernel(NULL)
122    {
123       ctx = spvContextCreate(spirv_target);
124    }
125 
~SPIRVKernelParser()126    ~SPIRVKernelParser()
127    {
128      spvContextDestroy(ctx);
129    }
130 
parseEntryPoint(const spv_parsed_instruction_t * ins)131    void parseEntryPoint(const spv_parsed_instruction_t *ins)
132    {
133       assert(ins->num_operands >= 3);
134 
135       const spv_parsed_operand_t *op = &ins->operands[1];
136 
137       assert(op->type == SPV_OPERAND_TYPE_ID);
138 
139       uint32_t funcId = ins->words[op->offset];
140 
141       for (auto &iter : kernels) {
142          if (funcId == iter.funcId)
143             return;
144       }
145 
146       op = &ins->operands[2];
147       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
148       const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
149 
150       kernels.push_back(SPIRVKernelInfo(funcId, name));
151    }
152 
parseFunction(const spv_parsed_instruction_t * ins)153    void parseFunction(const spv_parsed_instruction_t *ins)
154    {
155       assert(ins->num_operands == 4);
156 
157       const spv_parsed_operand_t *op = &ins->operands[1];
158 
159       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
160 
161       uint32_t funcId = ins->words[op->offset];
162 
163       for (auto &kernel : kernels) {
164          if (funcId == kernel.funcId && !kernel.args.size()) {
165             curKernel = &kernel;
166 	    return;
167          }
168       }
169    }
170 
parseFunctionParam(const spv_parsed_instruction_t * ins)171    void parseFunctionParam(const spv_parsed_instruction_t *ins)
172    {
173       const spv_parsed_operand_t *op;
174       uint32_t id, typeId;
175 
176       if (!curKernel)
177          return;
178 
179       assert(ins->num_operands == 2);
180       op = &ins->operands[0];
181       assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
182       typeId = ins->words[op->offset];
183       op = &ins->operands[1];
184       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
185       id = ins->words[op->offset];
186       curKernel->args.push_back(SPIRVKernelArg(id, typeId));
187    }
188 
parseName(const spv_parsed_instruction_t * ins)189    void parseName(const spv_parsed_instruction_t *ins)
190    {
191       const spv_parsed_operand_t *op;
192       const char *name;
193       uint32_t id;
194 
195       assert(ins->num_operands == 2);
196 
197       op = &ins->operands[0];
198       assert(op->type == SPV_OPERAND_TYPE_ID);
199       id = ins->words[op->offset];
200       op = &ins->operands[1];
201       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
202       name = reinterpret_cast<const char *>(ins->words + op->offset);
203 
204       for (auto &kernel : kernels) {
205          for (auto &arg : kernel.args) {
206             if (arg.id == id && arg.name.empty()) {
207               arg.name = name;
208               break;
209 	    }
210          }
211       }
212    }
213 
parseTypePointer(const spv_parsed_instruction_t * ins)214    void parseTypePointer(const spv_parsed_instruction_t *ins)
215    {
216       enum clc_kernel_arg_address_qualifier addrQualifier;
217       uint32_t typeId, storageClass;
218       const spv_parsed_operand_t *op;
219 
220       assert(ins->num_operands == 3);
221 
222       op = &ins->operands[0];
223       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
224       typeId = ins->words[op->offset];
225 
226       op = &ins->operands[1];
227       assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
228       storageClass = ins->words[op->offset];
229       switch (storageClass) {
230       case SpvStorageClassCrossWorkgroup:
231          addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
232          break;
233       case SpvStorageClassWorkgroup:
234          addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
235          break;
236       case SpvStorageClassUniformConstant:
237          addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
238          break;
239       default:
240          addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
241          break;
242       }
243 
244       for (auto &kernel : kernels) {
245 	 for (auto &arg : kernel.args) {
246             if (arg.typeId == typeId) {
247                arg.addrQualifier = addrQualifier;
248                if (addrQualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT)
249                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
250             }
251          }
252       }
253    }
254 
parseOpString(const spv_parsed_instruction_t * ins)255    void parseOpString(const spv_parsed_instruction_t *ins)
256    {
257       const spv_parsed_operand_t *op;
258       std::string str;
259 
260       assert(ins->num_operands == 2);
261 
262       op = &ins->operands[1];
263       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
264       str = reinterpret_cast<const char *>(ins->words + op->offset);
265 
266       size_t start = 0;
267       enum class string_type {
268          arg_type,
269          arg_type_qual,
270       } str_type;
271 
272       if (str.find("kernel_arg_type.") == 0) {
273          start = sizeof("kernel_arg_type.") - 1;
274          str_type = string_type::arg_type;
275       } else if (str.find("kernel_arg_type_qual.") == 0) {
276          start = sizeof("kernel_arg_type_qual.") - 1;
277          str_type = string_type::arg_type_qual;
278       } else {
279          return;
280       }
281 
282       for (auto &kernel : kernels) {
283          size_t pos;
284 
285 	 pos = str.find(kernel.name, start);
286          if (pos == std::string::npos ||
287              pos != start || str[start + kernel.name.size()] != '.')
288             continue;
289 
290 	 pos = start + kernel.name.size();
291          if (str[pos++] != '.')
292             continue;
293 
294          for (auto &arg : kernel.args) {
295             if (arg.name.empty())
296                break;
297 
298             size_t entryEnd = str.find(',', pos);
299 	    if (entryEnd == std::string::npos)
300                break;
301 
302             std::string entryVal = str.substr(pos, entryEnd - pos);
303             pos = entryEnd + 1;
304 
305             if (str_type == string_type::arg_type) {
306                arg.typeName = std::move(entryVal);
307             } else if (str_type == string_type::arg_type_qual) {
308                if (entryVal.find("const") != std::string::npos)
309                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
310             }
311          }
312       }
313    }
314 
applyDecoration(uint32_t id,const spv_parsed_instruction_t * ins)315    void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
316    {
317       auto iter = decorationGroups.find(id);
318       if (iter != decorationGroups.end()) {
319          for (uint32_t entry : iter->second)
320             applyDecoration(entry, ins);
321          return;
322       }
323 
324       const spv_parsed_operand_t *op;
325       uint32_t decoration;
326 
327       assert(ins->num_operands >= 2);
328 
329       op = &ins->operands[1];
330       assert(op->type == SPV_OPERAND_TYPE_DECORATION);
331       decoration = ins->words[op->offset];
332 
333       if (decoration == SpvDecorationSpecId) {
334          uint32_t spec_id = ins->words[ins->operands[2].offset];
335          for (auto &c : specConstants) {
336             if (c.second.id == spec_id) {
337                assert(c.first == id);
338                return;
339             }
340          }
341          specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
342          return;
343       }
344 
345       for (auto &kernel : kernels) {
346          for (auto &arg : kernel.args) {
347             if (arg.id == id) {
348                switch (decoration) {
349                case SpvDecorationVolatile:
350                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
351                   break;
352                case SpvDecorationConstant:
353                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
354                   break;
355                case SpvDecorationRestrict:
356                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
357                   break;
358                case SpvDecorationFuncParamAttr:
359                   op = &ins->operands[2];
360                   assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
361                   switch (ins->words[op->offset]) {
362                   case SpvFunctionParameterAttributeNoAlias:
363                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
364                      break;
365                   case SpvFunctionParameterAttributeNoWrite:
366                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
367                      break;
368                   }
369                   break;
370                }
371             }
372 
373          }
374       }
375    }
376 
parseOpDecorate(const spv_parsed_instruction_t * ins)377    void parseOpDecorate(const spv_parsed_instruction_t *ins)
378    {
379       const spv_parsed_operand_t *op;
380       uint32_t id;
381 
382       assert(ins->num_operands >= 2);
383 
384       op = &ins->operands[0];
385       assert(op->type == SPV_OPERAND_TYPE_ID);
386       id = ins->words[op->offset];
387 
388       applyDecoration(id, ins);
389    }
390 
parseOpGroupDecorate(const spv_parsed_instruction_t * ins)391    void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
392    {
393       assert(ins->num_operands >= 2);
394 
395       const spv_parsed_operand_t *op = &ins->operands[0];
396       assert(op->type == SPV_OPERAND_TYPE_ID);
397       uint32_t groupId = ins->words[op->offset];
398 
399       auto lowerBound = decorationGroups.lower_bound(groupId);
400       if (lowerBound != decorationGroups.end() &&
401           lowerBound->first == groupId)
402          // Group already filled out
403          return;
404 
405       auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
406       auto& vec = iter->second;
407       vec.reserve(ins->num_operands - 1);
408       for (uint32_t i = 1; i < ins->num_operands; ++i) {
409          op = &ins->operands[i];
410          assert(op->type == SPV_OPERAND_TYPE_ID);
411          vec.push_back(ins->words[op->offset]);
412       }
413    }
414 
parseOpTypeImage(const spv_parsed_instruction_t * ins)415    void parseOpTypeImage(const spv_parsed_instruction_t *ins)
416    {
417       const spv_parsed_operand_t *op;
418       uint32_t typeId;
419       unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
420 
421       op = &ins->operands[0];
422       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
423       typeId = ins->words[op->offset];
424 
425       if (ins->num_operands >= 9) {
426          op = &ins->operands[8];
427          assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
428          switch (ins->words[op->offset]) {
429          case SpvAccessQualifierReadOnly:
430             accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
431             break;
432          case SpvAccessQualifierWriteOnly:
433             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
434             break;
435          case SpvAccessQualifierReadWrite:
436             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
437                CLC_KERNEL_ARG_ACCESS_READ;
438             break;
439          }
440       }
441 
442       for (auto &kernel : kernels) {
443 	 for (auto &arg : kernel.args) {
444             if (arg.typeId == typeId) {
445                arg.accessQualifier = accessQualifier;
446                arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
447             }
448          }
449       }
450    }
451 
parseExecutionMode(const spv_parsed_instruction_t * ins)452    void parseExecutionMode(const spv_parsed_instruction_t *ins)
453    {
454       uint32_t executionMode = ins->words[ins->operands[1].offset];
455       uint32_t funcId = ins->words[ins->operands[0].offset];
456 
457       for (auto& kernel : kernels) {
458          if (kernel.funcId == funcId) {
459             switch (executionMode) {
460             case SpvExecutionModeVecTypeHint:
461                kernel.vecHint = ins->words[ins->operands[2].offset];
462                break;
463             case SpvExecutionModeLocalSize:
464                kernel.localSize[0] = ins->words[ins->operands[2].offset];
465                kernel.localSize[1] = ins->words[ins->operands[3].offset];
466                kernel.localSize[2] = ins->words[ins->operands[4].offset];
467             case SpvExecutionModeLocalSizeHint:
468                kernel.localSizeHint[0] = ins->words[ins->operands[2].offset];
469                kernel.localSizeHint[1] = ins->words[ins->operands[3].offset];
470                kernel.localSizeHint[2] = ins->words[ins->operands[4].offset];
471             default:
472                return;
473             }
474          }
475       }
476    }
477 
parseLiteralType(const spv_parsed_instruction_t * ins)478    void parseLiteralType(const spv_parsed_instruction_t *ins)
479    {
480       uint32_t typeId = ins->words[ins->operands[0].offset];
481       auto& literalType = literalTypes[typeId];
482       switch (ins->opcode) {
483       case SpvOpTypeBool:
484          literalType = CLC_SPEC_CONSTANT_BOOL;
485          break;
486       case SpvOpTypeFloat: {
487          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
488          switch (sizeInBits) {
489          case 32:
490             literalType = CLC_SPEC_CONSTANT_FLOAT;
491             break;
492          case 64:
493             literalType = CLC_SPEC_CONSTANT_DOUBLE;
494             break;
495          case 16:
496             /* Can't be used for a spec constant */
497             break;
498          default:
499             unreachable("Unexpected float bit size");
500          }
501          break;
502       }
503       case SpvOpTypeInt: {
504          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
505          bool isSigned = ins->words[ins->operands[2].offset];
506          if (isSigned) {
507             switch (sizeInBits) {
508             case 8:
509                literalType = CLC_SPEC_CONSTANT_INT8;
510                break;
511             case 16:
512                literalType = CLC_SPEC_CONSTANT_INT16;
513                break;
514             case 32:
515                literalType = CLC_SPEC_CONSTANT_INT32;
516                break;
517             case 64:
518                literalType = CLC_SPEC_CONSTANT_INT64;
519                break;
520             default:
521                unreachable("Unexpected int bit size");
522             }
523          } else {
524             switch (sizeInBits) {
525             case 8:
526                literalType = CLC_SPEC_CONSTANT_UINT8;
527                break;
528             case 16:
529                literalType = CLC_SPEC_CONSTANT_UINT16;
530                break;
531             case 32:
532                literalType = CLC_SPEC_CONSTANT_UINT32;
533                break;
534             case 64:
535                literalType = CLC_SPEC_CONSTANT_UINT64;
536                break;
537             default:
538                unreachable("Unexpected uint bit size");
539             }
540          }
541          break;
542       }
543       default:
544          unreachable("Unexpected type opcode");
545       }
546    }
547 
parseSpecConstant(const spv_parsed_instruction_t * ins)548    void parseSpecConstant(const spv_parsed_instruction_t *ins)
549    {
550       uint32_t id = ins->result_id;
551       for (auto& c : specConstants) {
552          if (c.first == id) {
553             auto& data = c.second;
554             switch (ins->opcode) {
555             case SpvOpSpecConstant: {
556                uint32_t typeId = ins->words[ins->operands[0].offset];
557 
558                // This better be an integer or float type
559                auto typeIter = literalTypes.find(typeId);
560                assert(typeIter != literalTypes.end());
561 
562                data.type = typeIter->second;
563                break;
564             }
565             case SpvOpSpecConstantFalse:
566             case SpvOpSpecConstantTrue:
567                data.type = CLC_SPEC_CONSTANT_BOOL;
568                break;
569             default:
570                unreachable("Composites and Ops are not directly specializable.");
571             }
572          }
573       }
574    }
575 
576    static spv_result_t
parseInstruction(void * data,const spv_parsed_instruction_t * ins)577    parseInstruction(void *data, const spv_parsed_instruction_t *ins)
578    {
579       SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
580 
581       switch (ins->opcode) {
582       case SpvOpName:
583          parser->parseName(ins);
584          break;
585       case SpvOpEntryPoint:
586          parser->parseEntryPoint(ins);
587          break;
588       case SpvOpFunction:
589          parser->parseFunction(ins);
590          break;
591       case SpvOpFunctionParameter:
592          parser->parseFunctionParam(ins);
593          break;
594       case SpvOpFunctionEnd:
595       case SpvOpLabel:
596          parser->curKernel = NULL;
597          break;
598       case SpvOpTypePointer:
599          parser->parseTypePointer(ins);
600          break;
601       case SpvOpTypeImage:
602          parser->parseOpTypeImage(ins);
603          break;
604       case SpvOpString:
605          parser->parseOpString(ins);
606          break;
607       case SpvOpDecorate:
608          parser->parseOpDecorate(ins);
609          break;
610       case SpvOpGroupDecorate:
611          parser->parseOpGroupDecorate(ins);
612          break;
613       case SpvOpExecutionMode:
614          parser->parseExecutionMode(ins);
615          break;
616       case SpvOpTypeBool:
617       case SpvOpTypeInt:
618       case SpvOpTypeFloat:
619          parser->parseLiteralType(ins);
620          break;
621       case SpvOpSpecConstant:
622       case SpvOpSpecConstantFalse:
623       case SpvOpSpecConstantTrue:
624          parser->parseSpecConstant(ins);
625          break;
626       default:
627          break;
628       }
629 
630       return SPV_SUCCESS;
631    }
632 
parseBinary(const struct clc_binary & spvbin,const struct clc_logger * logger)633    bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
634    {
635       /* 3 passes should be enough to retrieve all kernel information:
636        * 1st pass: all entry point name and number of args
637        * 2nd pass: argument names and type names
638        * 3rd pass: pointer type names
639        */
640       for (unsigned pass = 0; pass < 3; pass++) {
641          spv_diagnostic diagnostic = NULL;
642          auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
643                                       static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
644                                       NULL, parseInstruction, &diagnostic);
645 
646          if (result != SPV_SUCCESS) {
647             if (diagnostic && logger)
648                logger->error(logger->priv, diagnostic->error);
649             return false;
650          }
651       }
652 
653       return true;
654    }
655 
656    std::vector<SPIRVKernelInfo> kernels;
657    std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
658    std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
659    std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
660    SPIRVKernelInfo *curKernel;
661    spv_context ctx;
662 };
663 
664 bool
clc_spirv_get_kernels_info(const struct clc_binary * spvbin,const struct clc_kernel_info ** out_kernels,unsigned * num_kernels,const struct clc_parsed_spec_constant ** out_spec_constants,unsigned * num_spec_constants,const struct clc_logger * logger)665 clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
666                            const struct clc_kernel_info **out_kernels,
667                            unsigned *num_kernels,
668                            const struct clc_parsed_spec_constant **out_spec_constants,
669                            unsigned *num_spec_constants,
670                            const struct clc_logger *logger)
671 {
672    struct clc_kernel_info *kernels;
673    struct clc_parsed_spec_constant *spec_constants = NULL;
674 
675    SPIRVKernelParser parser;
676 
677    if (!parser.parseBinary(*spvbin, logger))
678       return false;
679 
680    *num_kernels = parser.kernels.size();
681    *num_spec_constants = parser.specConstants.size();
682    if (!*num_kernels)
683       return false;
684 
685    kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
686                                                                sizeof(*kernels)));
687    assert(kernels);
688    for (unsigned i = 0; i < parser.kernels.size(); i++) {
689       kernels[i].name = strdup(parser.kernels[i].name.c_str());
690       kernels[i].num_args = parser.kernels[i].args.size();
691       kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
692       kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
693       memcpy(kernels[i].local_size, parser.kernels[i].localSize, sizeof(kernels[i].local_size));
694       memcpy(kernels[i].local_size_hint, parser.kernels[i].localSizeHint, sizeof(kernels[i].local_size_hint));
695       if (!kernels[i].num_args)
696          continue;
697 
698       struct clc_kernel_arg *args;
699 
700       args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
701                                                        sizeof(*kernels->args)));
702       kernels[i].args = args;
703       assert(args);
704       for (unsigned j = 0; j < kernels[i].num_args; j++) {
705          if (!parser.kernels[i].args[j].name.empty())
706             args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
707          args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
708          args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
709          args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
710          args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
711       }
712    }
713 
714    if (*num_spec_constants) {
715       spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
716                                                                                   sizeof(*spec_constants)));
717       assert(spec_constants);
718 
719       for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
720          spec_constants[i] = parser.specConstants[i].second;
721       }
722    }
723 
724    *out_kernels = kernels;
725    *out_spec_constants = spec_constants;
726 
727    return true;
728 }
729 
730 void
clc_free_kernels_info(const struct clc_kernel_info * kernels,unsigned num_kernels)731 clc_free_kernels_info(const struct clc_kernel_info *kernels,
732                       unsigned num_kernels)
733 {
734    if (!kernels)
735       return;
736 
737    for (unsigned i = 0; i < num_kernels; i++) {
738       if (kernels[i].args) {
739          for (unsigned j = 0; j < kernels[i].num_args; j++) {
740             free((void *)kernels[i].args[j].name);
741             free((void *)kernels[i].args[j].type_name);
742          }
743       }
744       free((void *)kernels[i].name);
745    }
746 
747    free((void *)kernels);
748 }
749 
750 static std::unique_ptr<::llvm::Module>
clc_compile_to_llvm_module(LLVMContext & llvm_ctx,const struct clc_compile_args * args,const struct clc_logger * logger)751 clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
752                            const struct clc_compile_args *args,
753                            const struct clc_logger *logger)
754 {
755    std::string diag_log_str;
756    raw_string_ostream diag_log_stream { diag_log_str };
757 
758    std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
759 
760    clang::DiagnosticsEngine diag {
761       new clang::DiagnosticIDs,
762       new clang::DiagnosticOptions,
763       new clang::TextDiagnosticPrinter(diag_log_stream,
764                                        &c->getDiagnosticOpts())
765    };
766 
767    std::vector<const char *> clang_opts = {
768       args->source.name,
769       "-triple", "spir64-unknown-unknown",
770       // By default, clang prefers to use modules to pull in the default headers,
771       // which doesn't work with our technique of embedding the headers in our binary
772 #if LLVM_VERSION_MAJOR >= 15
773       "-fdeclare-opencl-builtins",
774 #else
775       "-finclude-default-header",
776 #endif
777 #if LLVM_VERSION_MAJOR >= 15
778       "-no-opaque-pointers",
779 #endif
780       // Add a default CL compiler version. Clang will pick the last one specified
781       // on the command line, so the app can override this one.
782       "-cl-std=cl1.2",
783       // The LLVM-SPIRV-Translator doesn't support memset with variable size
784       "-fno-builtin-memset",
785       // LLVM's optimizations can produce code that the translator can't translate
786       "-O0",
787       // Ensure inline functions are actually emitted
788       "-fgnu89-inline"
789    };
790    // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
791    // being provided by the caller here.
792    clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);
793 
794    if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
795 #if LLVM_VERSION_MAJOR >= 10
796                                                   clang_opts,
797 #else
798                                                   clang_opts.data(),
799                                                   clang_opts.data() + clang_opts.size(),
800 #endif
801                                                   diag)) {
802       clc_error(logger, "Couldn't create Clang invocation.\n");
803       return {};
804    }
805 
806    if (diag.hasErrorOccurred()) {
807       clc_error(logger, "%sErrors occurred during Clang invocation.\n",
808                 diag_log_str.c_str());
809       return {};
810    }
811 
812    // This is a workaround for a Clang bug which causes the number
813    // of warnings and errors to be printed to stderr.
814    // http://www.llvm.org/bugs/show_bug.cgi?id=19735
815    c->getDiagnosticOpts().ShowCarets = false;
816 
817    c->createDiagnostics(new clang::TextDiagnosticPrinter(
818                            diag_log_stream,
819                            &c->getDiagnosticOpts()));
820 
821    c->setTarget(clang::TargetInfo::CreateTargetInfo(
822                    c->getDiagnostics(), c->getInvocation().TargetOpts));
823 
824    c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;
825 
826 #ifdef USE_STATIC_OPENCL_C_H
827    c->getHeaderSearchOpts().UseBuiltinIncludes = false;
828    c->getHeaderSearchOpts().UseStandardSystemIncludes = false;
829 
830    // Add opencl-c generic search path
831    {
832       ::llvm::SmallString<128> system_header_path;
833       ::llvm::sys::path::system_temp_directory(true, system_header_path);
834       ::llvm::sys::path::append(system_header_path, "openclon12");
835       c->getHeaderSearchOpts().AddPath(system_header_path.str(),
836                                        clang::frontend::Angled,
837                                        false, false);
838 
839 #if LLVM_VERSION_MAJOR < 15
840       ::llvm::sys::path::append(system_header_path, "opencl-c.h");
841       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
842          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());
843 #endif
844 
845       ::llvm::sys::path::remove_filename(system_header_path);
846       ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
847       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
848          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
849    }
850 #else
851    c->getHeaderSearchOpts().UseBuiltinIncludes = true;
852    c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
853    c->getHeaderSearchOpts().ResourceDir = CLANG_RESOURCE_DIR;
854 
855    // Add opencl-c generic search path
856    c->getHeaderSearchOpts().AddPath(CLANG_RESOURCE_DIR,
857                                     clang::frontend::Angled,
858                                     false, false);
859    // Add opencl include
860 #if LLVM_VERSION_MAJOR >= 15
861    c->getPreprocessorOpts().Includes.push_back("opencl-c-base.h");
862 #else
863    c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
864 #endif
865 #endif
866 
867 #if LLVM_VERSION_MAJOR >= 14
868    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("-all");
869    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_byte_addressable_store");
870    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_base_atomics");
871    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_extended_atomics");
872    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_base_atomics");
873    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_extended_atomics");
874    if (args->features.fp16) {
875       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp16");
876    }
877    if (args->features.fp64) {
878       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp64");
879       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_fp64");
880    }
881    if (args->features.int64) {
882       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cles_khr_int64");
883       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_int64");
884    }
885    if (args->features.images) {
886       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_images");
887    }
888    if (args->features.images_read_write) {
889       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_read_write_images");
890    }
891    if (args->features.images_write_3d) {
892       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_3d_image_writes");
893       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_3d_image_writes");
894    }
895    if (args->features.intel_subgroups) {
896       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_intel_subgroups");
897    }
898    if (args->features.subgroups) {
899       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_subgroups");
900    }
901 #endif
902 
903    if (args->num_headers) {
904       ::llvm::SmallString<128> tmp_header_path;
905       ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
906       ::llvm::sys::path::append(tmp_header_path, "openclon12");
907 
908       c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
909                                        clang::frontend::Quoted,
910                                        false, false);
911 
912       for (size_t i = 0; i < args->num_headers; i++) {
913          auto path_copy = tmp_header_path;
914          ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
915          c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
916             ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
917       }
918    }
919 
920    c->getPreprocessorOpts().addRemappedFile(
921            args->source.name,
922            ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
923 
924    // Compile the code
925    clang::EmitLLVMOnlyAction act(&llvm_ctx);
926    if (!c->ExecuteAction(act)) {
927       clc_error(logger, "%sError executing LLVM compilation action.\n",
928                 diag_log_str.c_str());
929       return {};
930    }
931 
932    return act.takeModule();
933 }
934 
935 static SPIRV::VersionNumber
spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)936 spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
937 {
938    switch (version) {
939    case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
940    case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
941    case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
942    case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
943    case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
944 #ifdef HAS_SPIRV_1_4
945    case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
946 #endif
947    default:      return invalid_spirv_trans_version;
948    }
949 }
950 
951 static int
llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,LLVMContext & context,const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)952 llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
953                   LLVMContext &context,
954                   const struct clc_compile_args *args,
955                   const struct clc_logger *logger,
956                   struct clc_binary *out_spirv)
957 {
958    std::string log;
959 
960    SPIRV::VersionNumber version =
961       spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
962    if (version == invalid_spirv_trans_version) {
963       clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
964       return -1;
965    }
966 
967    const char *const *extensions = NULL;
968    if (args)
969       extensions = args->allowed_spirv_extensions;
970    if (!extensions) {
971       /* The SPIR-V parser doesn't handle all extensions */
972       static const char *default_extensions[] = {
973          "SPV_EXT_shader_atomic_float_add",
974          "SPV_EXT_shader_atomic_float_min_max",
975          "SPV_KHR_float_controls",
976          NULL,
977       };
978       extensions = default_extensions;
979    }
980 
981    SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
982    for (int i = 0; extensions[i]; i++) {
983 #define EXT(X) \
984       if (strcmp(#X, extensions[i]) == 0) \
985          ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
986 #include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
987 #undef EXT
988    }
989    SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
990 
991 #if LLVM_VERSION_MAJOR >= 13
992    /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
993    spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
994 #endif
995 
996    std::ostringstream spv_stream;
997    if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
998       clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
999                 log.c_str());
1000       return -1;
1001    }
1002 
1003    const std::string spv_out = spv_stream.str();
1004    out_spirv->size = spv_out.size();
1005    out_spirv->data = malloc(out_spirv->size);
1006    memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
1007 
1008    return 0;
1009 }
1010 
1011 int
clc_c_to_spir(const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spir)1012 clc_c_to_spir(const struct clc_compile_args *args,
1013               const struct clc_logger *logger,
1014               struct clc_binary *out_spir)
1015 {
1016    clc_initialize_llvm();
1017 
1018    LLVMContext llvm_ctx;
1019    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1020                                          const_cast<clc_logger *>(logger));
1021 
1022    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1023    if (!mod)
1024       return -1;
1025 
1026    ::llvm::SmallVector<char, 0> buffer;
1027    ::llvm::BitcodeWriter writer(buffer);
1028    writer.writeModule(*mod);
1029 
1030    out_spir->size = buffer.size_in_bytes();
1031    out_spir->data = malloc(out_spir->size);
1032    memcpy(out_spir->data, buffer.data(), out_spir->size);
1033 
1034    return 0;
1035 }
1036 
1037 int
clc_c_to_spirv(const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)1038 clc_c_to_spirv(const struct clc_compile_args *args,
1039                const struct clc_logger *logger,
1040                struct clc_binary *out_spirv)
1041 {
1042    clc_initialize_llvm();
1043 
1044    LLVMContext llvm_ctx;
1045    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1046                                          const_cast<clc_logger *>(logger));
1047 
1048    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1049    if (!mod)
1050       return -1;
1051    return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
1052 }
1053 
1054 int
clc_spir_to_spirv(const struct clc_binary * in_spir,const struct clc_logger * logger,struct clc_binary * out_spirv)1055 clc_spir_to_spirv(const struct clc_binary *in_spir,
1056                   const struct clc_logger *logger,
1057                   struct clc_binary *out_spirv)
1058 {
1059    clc_initialize_llvm();
1060 
1061    LLVMContext llvm_ctx;
1062    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1063                                          const_cast<clc_logger *>(logger));
1064 
1065    ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
1066    auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
1067    if (!mod)
1068       return -1;
1069 
1070    return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
1071 }
1072 
1073 class SPIRVMessageConsumer {
1074 public:
SPIRVMessageConsumer(const struct clc_logger * logger)1075    SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
1076 
operator ()(spv_message_level_t level,const char * src,const spv_position_t & pos,const char * msg)1077    void operator()(spv_message_level_t level, const char *src,
1078                    const spv_position_t &pos, const char *msg)
1079    {
1080       if (level == SPV_MSG_INFO || level == SPV_MSG_DEBUG)
1081          return;
1082 
1083       std::ostringstream message;
1084       message << "(file=" << src
1085               << ",line=" << pos.line
1086               << ",column=" << pos.column
1087               << ",index=" << pos.index
1088               << "): " << msg << "\n";
1089 
1090       if (level == SPV_MSG_WARNING)
1091          clc_warning(logger, "%s", message.str().c_str());
1092       else
1093          clc_error(logger, "%s", message.str().c_str());
1094    }
1095 
1096 private:
1097    const struct clc_logger *logger;
1098 };
1099 
1100 int
clc_link_spirv_binaries(const struct clc_linker_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)1101 clc_link_spirv_binaries(const struct clc_linker_args *args,
1102                         const struct clc_logger *logger,
1103                         struct clc_binary *out_spirv)
1104 {
1105    std::vector<std::vector<uint32_t>> binaries;
1106 
1107    for (unsigned i = 0; i < args->num_in_objs; i++) {
1108       const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1109       std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1110       binaries.push_back(bin);
1111    }
1112 
1113    SPIRVMessageConsumer msgconsumer(logger);
1114    spvtools::Context context(spirv_target);
1115    context.SetMessageConsumer(msgconsumer);
1116    spvtools::LinkerOptions options;
1117    options.SetAllowPartialLinkage(args->create_library);
1118    options.SetCreateLibrary(args->create_library);
1119    std::vector<uint32_t> linkingResult;
1120    spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1121    if (status != SPV_SUCCESS) {
1122       return -1;
1123    }
1124 
1125    out_spirv->size = linkingResult.size() * 4;
1126    out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1127    memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1128 
1129    return 0;
1130 }
1131 
1132 int
clc_spirv_specialize(const struct clc_binary * in_spirv,const struct clc_parsed_spirv * parsed_data,const struct clc_spirv_specialization_consts * consts,struct clc_binary * out_spirv)1133 clc_spirv_specialize(const struct clc_binary *in_spirv,
1134                      const struct clc_parsed_spirv *parsed_data,
1135                      const struct clc_spirv_specialization_consts *consts,
1136                      struct clc_binary *out_spirv)
1137 {
1138    std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1139    for (unsigned i = 0; i < consts->num_specializations; ++i) {
1140       unsigned id = consts->specializations[i].id;
1141       auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1142          parsed_data->spec_constants + parsed_data->num_spec_constants,
1143          [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1144       assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1145 
1146       std::vector<uint32_t> words;
1147       switch (parsed_spec_const->type) {
1148       case CLC_SPEC_CONSTANT_BOOL:
1149          words.push_back(consts->specializations[i].value.b);
1150          break;
1151       case CLC_SPEC_CONSTANT_INT32:
1152       case CLC_SPEC_CONSTANT_UINT32:
1153       case CLC_SPEC_CONSTANT_FLOAT:
1154          words.push_back(consts->specializations[i].value.u32);
1155          break;
1156       case CLC_SPEC_CONSTANT_INT16:
1157          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1158          break;
1159       case CLC_SPEC_CONSTANT_INT8:
1160          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1161          break;
1162       case CLC_SPEC_CONSTANT_UINT16:
1163          words.push_back((uint32_t)consts->specializations[i].value.u16);
1164          break;
1165       case CLC_SPEC_CONSTANT_UINT8:
1166          words.push_back((uint32_t)consts->specializations[i].value.u8);
1167          break;
1168       case CLC_SPEC_CONSTANT_DOUBLE:
1169       case CLC_SPEC_CONSTANT_INT64:
1170       case CLC_SPEC_CONSTANT_UINT64:
1171          words.resize(2);
1172          memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1173          break;
1174       case CLC_SPEC_CONSTANT_UNKNOWN:
1175          assert(0);
1176          break;
1177       }
1178 
1179       ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1180       assert(ret.second);
1181    }
1182 
1183    spvtools::Optimizer opt(spirv_target);
1184    opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1185 
1186    std::vector<uint32_t> result;
1187    if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1188       return false;
1189 
1190    out_spirv->size = result.size() * 4;
1191    out_spirv->data = malloc(out_spirv->size);
1192    memcpy(out_spirv->data, result.data(), out_spirv->size);
1193    return true;
1194 }
1195 
1196 void
clc_dump_spirv(const struct clc_binary * spvbin,FILE * f)1197 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1198 {
1199    spvtools::SpirvTools tools(spirv_target);
1200    const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1201    std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1202    std::string out;
1203    tools.Disassemble(bin, &out,
1204                      SPV_BINARY_TO_TEXT_OPTION_INDENT |
1205                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1206    fwrite(out.c_str(), out.size(), 1, f);
1207 }
1208 
1209 void
clc_free_spir_binary(struct clc_binary * spir)1210 clc_free_spir_binary(struct clc_binary *spir)
1211 {
1212    free(spir->data);
1213 }
1214 
1215 void
clc_free_spirv_binary(struct clc_binary * spvbin)1216 clc_free_spirv_binary(struct clc_binary *spvbin)
1217 {
1218    free(spvbin->data);
1219 }
1220 
/* Run the LLVM-C initializers for all targets: targets, target infos,
 * target MCs, asm parsers and asm printers, in that order.
 *
 * Invoked through std::call_once() by clc_initialize_llvm().
 */
void
initialize_llvm_once(void)
{
   LLVMInitializeAllTargets();
   LLVMInitializeAllTargetInfos();
   LLVMInitializeAllTargetMCs();
   LLVMInitializeAllAsmParsers();
   LLVMInitializeAllAsmPrinters();
}
1230 
1231 std::once_flag initialize_llvm_once_flag;
1232 
1233 void
clc_initialize_llvm(void)1234 clc_initialize_llvm(void)
1235 {
1236    std::call_once(initialize_llvm_once_flag,
1237                   []() { initialize_llvm_once(); });
1238 }
1239