• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2012-2016 Francisco Jerez
3 // Copyright 2012-2016 Advanced Micro Devices, Inc.
4 // Copyright 2014-2016 Jan Vesely
5 // Copyright 2014-2015 Serge Martin
6 // Copyright 2015 Zoltan Gilian
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a
9 // copy of this software and associated documentation files (the "Software"),
10 // to deal in the Software without restriction, including without limitation
11 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
12 // and/or sell copies of the Software, and to permit persons to whom the
13 // Software is furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24 // OTHER DEALINGS IN THE SOFTWARE.
25 
26 #include <sstream>
27 
28 #include <llvm/ADT/ArrayRef.h>
29 #include <llvm/IR/DiagnosticPrinter.h>
30 #include <llvm/IR/DiagnosticInfo.h>
31 #include <llvm/IR/LLVMContext.h>
32 #include <llvm/IR/Type.h>
33 #include <llvm/Support/raw_ostream.h>
34 #include <llvm/Bitcode/BitcodeWriter.h>
35 #include <llvm/Bitcode/BitcodeReader.h>
36 #include <llvm-c/Core.h>
37 #include <llvm-c/Target.h>
38 #include <LLVMSPIRVLib/LLVMSPIRVLib.h>
39 
40 #include <clang/CodeGen/CodeGenAction.h>
41 #include <clang/Lex/PreprocessorOptions.h>
42 #include <clang/Frontend/CompilerInstance.h>
43 #include <clang/Frontend/TextDiagnosticBuffer.h>
44 #include <clang/Frontend/TextDiagnosticPrinter.h>
45 #include <clang/Basic/TargetInfo.h>
46 
47 #include <spirv-tools/libspirv.hpp>
48 #include <spirv-tools/linker.hpp>
49 #include <spirv-tools/optimizer.hpp>
50 
51 #include "util/macros.h"
52 #include "glsl_types.h"
53 
54 #include "spirv.h"
55 
56 #ifdef USE_STATIC_OPENCL_C_H
57 #include "opencl-c.h.h"
58 #include "opencl-c-base.h.h"
59 #endif
60 
61 #include "clc_helpers.h"
62 
63 /* Use the highest version of SPIRV supported by SPIRV-Tools. */
64 constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;
65 
66 constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
67 
68 using ::llvm::Function;
69 using ::llvm::LLVMContext;
70 using ::llvm::Module;
71 using ::llvm::raw_string_ostream;
72 
73 static void
llvm_log_handler(const::llvm::DiagnosticInfo & di,void * data)74 llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
75    raw_string_ostream os { *reinterpret_cast<std::string *>(data) };
76    ::llvm::DiagnosticPrinterRawOStream printer { os };
77    di.print(printer);
78 }
79 
80 class SPIRVKernelArg {
81 public:
SPIRVKernelArg(uint32_t id,uint32_t typeId)82    SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
83                                                   addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
84                                                   accessQualifier(0),
85                                                   typeQualifier(0) { }
~SPIRVKernelArg()86    ~SPIRVKernelArg() { }
87 
88    uint32_t id;
89    uint32_t typeId;
90    std::string name;
91    std::string typeName;
92    enum clc_kernel_arg_address_qualifier addrQualifier;
93    unsigned accessQualifier;
94    unsigned typeQualifier;
95 };
96 
97 class SPIRVKernelInfo {
98 public:
SPIRVKernelInfo(uint32_t fid,const char * nm)99    SPIRVKernelInfo(uint32_t fid, const char *nm) : funcId(fid), name(nm), vecHint(0) { }
~SPIRVKernelInfo()100    ~SPIRVKernelInfo() { }
101 
102    uint32_t funcId;
103    std::string name;
104    std::vector<SPIRVKernelArg> args;
105    unsigned vecHint;
106 };
107 
108 class SPIRVKernelParser {
109 public:
SPIRVKernelParser()110    SPIRVKernelParser() : curKernel(NULL)
111    {
112       ctx = spvContextCreate(spirv_target);
113    }
114 
~SPIRVKernelParser()115    ~SPIRVKernelParser()
116    {
117      spvContextDestroy(ctx);
118    }
119 
parseEntryPoint(const spv_parsed_instruction_t * ins)120    void parseEntryPoint(const spv_parsed_instruction_t *ins)
121    {
122       assert(ins->num_operands >= 3);
123 
124       const spv_parsed_operand_t *op = &ins->operands[1];
125 
126       assert(op->type == SPV_OPERAND_TYPE_ID);
127 
128       uint32_t funcId = ins->words[op->offset];
129 
130       for (auto &iter : kernels) {
131          if (funcId == iter.funcId)
132             return;
133       }
134 
135       op = &ins->operands[2];
136       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
137       const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
138 
139       kernels.push_back(SPIRVKernelInfo(funcId, name));
140    }
141 
parseFunction(const spv_parsed_instruction_t * ins)142    void parseFunction(const spv_parsed_instruction_t *ins)
143    {
144       assert(ins->num_operands == 4);
145 
146       const spv_parsed_operand_t *op = &ins->operands[1];
147 
148       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
149 
150       uint32_t funcId = ins->words[op->offset];
151 
152       for (auto &kernel : kernels) {
153          if (funcId == kernel.funcId && !kernel.args.size()) {
154             curKernel = &kernel;
155 	    return;
156          }
157       }
158    }
159 
parseFunctionParam(const spv_parsed_instruction_t * ins)160    void parseFunctionParam(const spv_parsed_instruction_t *ins)
161    {
162       const spv_parsed_operand_t *op;
163       uint32_t id, typeId;
164 
165       if (!curKernel)
166          return;
167 
168       assert(ins->num_operands == 2);
169       op = &ins->operands[0];
170       assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
171       typeId = ins->words[op->offset];
172       op = &ins->operands[1];
173       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
174       id = ins->words[op->offset];
175       curKernel->args.push_back(SPIRVKernelArg(id, typeId));
176    }
177 
parseName(const spv_parsed_instruction_t * ins)178    void parseName(const spv_parsed_instruction_t *ins)
179    {
180       const spv_parsed_operand_t *op;
181       const char *name;
182       uint32_t id;
183 
184       assert(ins->num_operands == 2);
185 
186       op = &ins->operands[0];
187       assert(op->type == SPV_OPERAND_TYPE_ID);
188       id = ins->words[op->offset];
189       op = &ins->operands[1];
190       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
191       name = reinterpret_cast<const char *>(ins->words + op->offset);
192 
193       for (auto &kernel : kernels) {
194          for (auto &arg : kernel.args) {
195             if (arg.id == id && arg.name.empty()) {
196               arg.name = name;
197               break;
198 	    }
199          }
200       }
201    }
202 
parseTypePointer(const spv_parsed_instruction_t * ins)203    void parseTypePointer(const spv_parsed_instruction_t *ins)
204    {
205       enum clc_kernel_arg_address_qualifier addrQualifier;
206       uint32_t typeId, storageClass;
207       const spv_parsed_operand_t *op;
208 
209       assert(ins->num_operands == 3);
210 
211       op = &ins->operands[0];
212       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
213       typeId = ins->words[op->offset];
214 
215       op = &ins->operands[1];
216       assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
217       storageClass = ins->words[op->offset];
218       switch (storageClass) {
219       case SpvStorageClassCrossWorkgroup:
220          addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
221          break;
222       case SpvStorageClassWorkgroup:
223          addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
224          break;
225       case SpvStorageClassUniformConstant:
226          addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
227          break;
228       default:
229          addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
230          break;
231       }
232 
233       for (auto &kernel : kernels) {
234 	 for (auto &arg : kernel.args) {
235             if (arg.typeId == typeId)
236                arg.addrQualifier = addrQualifier;
237          }
238       }
239    }
240 
parseOpString(const spv_parsed_instruction_t * ins)241    void parseOpString(const spv_parsed_instruction_t *ins)
242    {
243       const spv_parsed_operand_t *op;
244       std::string str;
245 
246       assert(ins->num_operands == 2);
247 
248       op = &ins->operands[1];
249       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
250       str = reinterpret_cast<const char *>(ins->words + op->offset);
251 
252       if (str.find("kernel_arg_type.") != 0)
253          return;
254 
255       size_t start = sizeof("kernel_arg_type.") - 1;
256 
257       for (auto &kernel : kernels) {
258          size_t pos;
259 
260 	 pos = str.find(kernel.name, start);
261          if (pos == std::string::npos ||
262              pos != start || str[start + kernel.name.size()] != '.')
263             continue;
264 
265 	 pos = start + kernel.name.size();
266          if (str[pos++] != '.')
267             continue;
268 
269          for (auto &arg : kernel.args) {
270             if (arg.name.empty())
271                break;
272 
273             size_t typeEnd = str.find(',', pos);
274 	    if (typeEnd == std::string::npos)
275                break;
276 
277             arg.typeName = str.substr(pos, typeEnd - pos);
278             pos = typeEnd + 1;
279          }
280       }
281    }
282 
applyDecoration(uint32_t id,const spv_parsed_instruction_t * ins)283    void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
284    {
285       auto iter = decorationGroups.find(id);
286       if (iter != decorationGroups.end()) {
287          for (uint32_t entry : iter->second)
288             applyDecoration(entry, ins);
289          return;
290       }
291 
292       const spv_parsed_operand_t *op;
293       uint32_t decoration;
294 
295       assert(ins->num_operands >= 2);
296 
297       op = &ins->operands[1];
298       assert(op->type == SPV_OPERAND_TYPE_DECORATION);
299       decoration = ins->words[op->offset];
300 
301       if (decoration == SpvDecorationSpecId) {
302          uint32_t spec_id = ins->words[ins->operands[2].offset];
303          for (auto &c : specConstants) {
304             if (c.second.id == spec_id) {
305                assert(c.first == id);
306                return;
307             }
308          }
309          specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
310          return;
311       }
312 
313       for (auto &kernel : kernels) {
314          for (auto &arg : kernel.args) {
315             if (arg.id == id) {
316                switch (decoration) {
317                case SpvDecorationVolatile:
318                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
319                   break;
320                case SpvDecorationConstant:
321                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
322                   break;
323                case SpvDecorationRestrict:
324                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
325                   break;
326                case SpvDecorationFuncParamAttr:
327                   op = &ins->operands[2];
328                   assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
329                   switch (ins->words[op->offset]) {
330                   case SpvFunctionParameterAttributeNoAlias:
331                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
332                      break;
333                   case SpvFunctionParameterAttributeNoWrite:
334                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
335                      break;
336                   }
337                   break;
338                }
339             }
340 
341          }
342       }
343    }
344 
parseOpDecorate(const spv_parsed_instruction_t * ins)345    void parseOpDecorate(const spv_parsed_instruction_t *ins)
346    {
347       const spv_parsed_operand_t *op;
348       uint32_t id;
349 
350       assert(ins->num_operands >= 2);
351 
352       op = &ins->operands[0];
353       assert(op->type == SPV_OPERAND_TYPE_ID);
354       id = ins->words[op->offset];
355 
356       applyDecoration(id, ins);
357    }
358 
parseOpGroupDecorate(const spv_parsed_instruction_t * ins)359    void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
360    {
361       assert(ins->num_operands >= 2);
362 
363       const spv_parsed_operand_t *op = &ins->operands[0];
364       assert(op->type == SPV_OPERAND_TYPE_ID);
365       uint32_t groupId = ins->words[op->offset];
366 
367       auto lowerBound = decorationGroups.lower_bound(groupId);
368       if (lowerBound != decorationGroups.end() &&
369           lowerBound->first == groupId)
370          // Group already filled out
371          return;
372 
373       auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
374       auto& vec = iter->second;
375       vec.reserve(ins->num_operands - 1);
376       for (uint32_t i = 1; i < ins->num_operands; ++i) {
377          op = &ins->operands[i];
378          assert(op->type == SPV_OPERAND_TYPE_ID);
379          vec.push_back(ins->words[op->offset]);
380       }
381    }
382 
parseOpTypeImage(const spv_parsed_instruction_t * ins)383    void parseOpTypeImage(const spv_parsed_instruction_t *ins)
384    {
385       const spv_parsed_operand_t *op;
386       uint32_t typeId;
387       unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
388 
389       op = &ins->operands[0];
390       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
391       typeId = ins->words[op->offset];
392 
393       if (ins->num_operands >= 9) {
394          op = &ins->operands[8];
395          assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
396          switch (ins->words[op->offset]) {
397          case SpvAccessQualifierReadOnly:
398             accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
399             break;
400          case SpvAccessQualifierWriteOnly:
401             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
402             break;
403          case SpvAccessQualifierReadWrite:
404             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
405                CLC_KERNEL_ARG_ACCESS_READ;
406             break;
407          }
408       }
409 
410       for (auto &kernel : kernels) {
411 	 for (auto &arg : kernel.args) {
412             if (arg.typeId == typeId) {
413                arg.accessQualifier = accessQualifier;
414                arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
415             }
416          }
417       }
418    }
419 
parseExecutionMode(const spv_parsed_instruction_t * ins)420    void parseExecutionMode(const spv_parsed_instruction_t *ins)
421    {
422       uint32_t executionMode = ins->words[ins->operands[1].offset];
423       if (executionMode != SpvExecutionModeVecTypeHint)
424          return;
425 
426       uint32_t funcId = ins->words[ins->operands[0].offset];
427       uint32_t vecHint = ins->words[ins->operands[2].offset];
428       for (auto& kernel : kernels) {
429          if (kernel.funcId == funcId)
430             kernel.vecHint = vecHint;
431       }
432    }
433 
parseLiteralType(const spv_parsed_instruction_t * ins)434    void parseLiteralType(const spv_parsed_instruction_t *ins)
435    {
436       uint32_t typeId = ins->words[ins->operands[0].offset];
437       auto& literalType = literalTypes[typeId];
438       switch (ins->opcode) {
439       case SpvOpTypeBool:
440          literalType = CLC_SPEC_CONSTANT_BOOL;
441          break;
442       case SpvOpTypeFloat: {
443          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
444          switch (sizeInBits) {
445          case 32:
446             literalType = CLC_SPEC_CONSTANT_FLOAT;
447             break;
448          case 64:
449             literalType = CLC_SPEC_CONSTANT_DOUBLE;
450             break;
451          case 16:
452             /* Can't be used for a spec constant */
453             break;
454          default:
455             unreachable("Unexpected float bit size");
456          }
457          break;
458       }
459       case SpvOpTypeInt: {
460          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
461          bool isSigned = ins->words[ins->operands[2].offset];
462          if (isSigned) {
463             switch (sizeInBits) {
464             case 8:
465                literalType = CLC_SPEC_CONSTANT_INT8;
466                break;
467             case 16:
468                literalType = CLC_SPEC_CONSTANT_INT16;
469                break;
470             case 32:
471                literalType = CLC_SPEC_CONSTANT_INT32;
472                break;
473             case 64:
474                literalType = CLC_SPEC_CONSTANT_INT64;
475                break;
476             default:
477                unreachable("Unexpected int bit size");
478             }
479          } else {
480             switch (sizeInBits) {
481             case 8:
482                literalType = CLC_SPEC_CONSTANT_UINT8;
483                break;
484             case 16:
485                literalType = CLC_SPEC_CONSTANT_UINT16;
486                break;
487             case 32:
488                literalType = CLC_SPEC_CONSTANT_UINT32;
489                break;
490             case 64:
491                literalType = CLC_SPEC_CONSTANT_UINT64;
492                break;
493             default:
494                unreachable("Unexpected uint bit size");
495             }
496          }
497          break;
498       }
499       default:
500          unreachable("Unexpected type opcode");
501       }
502    }
503 
parseSpecConstant(const spv_parsed_instruction_t * ins)504    void parseSpecConstant(const spv_parsed_instruction_t *ins)
505    {
506       uint32_t id = ins->result_id;
507       for (auto& c : specConstants) {
508          if (c.first == id) {
509             auto& data = c.second;
510             switch (ins->opcode) {
511             case SpvOpSpecConstant: {
512                uint32_t typeId = ins->words[ins->operands[0].offset];
513 
514                // This better be an integer or float type
515                auto typeIter = literalTypes.find(typeId);
516                assert(typeIter != literalTypes.end());
517 
518                data.type = typeIter->second;
519                break;
520             }
521             case SpvOpSpecConstantFalse:
522             case SpvOpSpecConstantTrue:
523                data.type = CLC_SPEC_CONSTANT_BOOL;
524                break;
525             default:
526                unreachable("Composites and Ops are not directly specializable.");
527             }
528          }
529       }
530    }
531 
532    static spv_result_t
parseInstruction(void * data,const spv_parsed_instruction_t * ins)533    parseInstruction(void *data, const spv_parsed_instruction_t *ins)
534    {
535       SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
536 
537       switch (ins->opcode) {
538       case SpvOpName:
539          parser->parseName(ins);
540          break;
541       case SpvOpEntryPoint:
542          parser->parseEntryPoint(ins);
543          break;
544       case SpvOpFunction:
545          parser->parseFunction(ins);
546          break;
547       case SpvOpFunctionParameter:
548          parser->parseFunctionParam(ins);
549          break;
550       case SpvOpFunctionEnd:
551       case SpvOpLabel:
552          parser->curKernel = NULL;
553          break;
554       case SpvOpTypePointer:
555          parser->parseTypePointer(ins);
556          break;
557       case SpvOpTypeImage:
558          parser->parseOpTypeImage(ins);
559          break;
560       case SpvOpString:
561          parser->parseOpString(ins);
562          break;
563       case SpvOpDecorate:
564          parser->parseOpDecorate(ins);
565          break;
566       case SpvOpGroupDecorate:
567          parser->parseOpGroupDecorate(ins);
568          break;
569       case SpvOpExecutionMode:
570          parser->parseExecutionMode(ins);
571          break;
572       case SpvOpTypeBool:
573       case SpvOpTypeInt:
574       case SpvOpTypeFloat:
575          parser->parseLiteralType(ins);
576          break;
577       case SpvOpSpecConstant:
578       case SpvOpSpecConstantFalse:
579       case SpvOpSpecConstantTrue:
580          parser->parseSpecConstant(ins);
581          break;
582       default:
583          break;
584       }
585 
586       return SPV_SUCCESS;
587    }
588 
parsingComplete()589    bool parsingComplete()
590    {
591       for (auto &kernel : kernels) {
592          if (kernel.name.empty())
593             return false;
594 
595          for (auto &arg : kernel.args) {
596             if (arg.name.empty() || arg.typeName.empty())
597                return false;
598          }
599       }
600 
601       return true;
602    }
603 
parseBinary(const struct clc_binary & spvbin,const struct clc_logger * logger)604    bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
605    {
606       /* 3 passes should be enough to retrieve all kernel information:
607        * 1st pass: all entry point name and number of args
608        * 2nd pass: argument names and type names
609        * 3rd pass: pointer type names
610        */
611       for (unsigned pass = 0; pass < 3; pass++) {
612          spv_diagnostic diagnostic = NULL;
613          auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
614                                       static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
615                                       NULL, parseInstruction, &diagnostic);
616 
617          if (result != SPV_SUCCESS) {
618             if (diagnostic && logger)
619                logger->error(logger->priv, diagnostic->error);
620             return false;
621          }
622 
623          if (parsingComplete())
624             return true;
625       }
626 
627       assert(0);
628       return false;
629    }
630 
631    std::vector<SPIRVKernelInfo> kernels;
632    std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
633    std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
634    std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
635    SPIRVKernelInfo *curKernel;
636    spv_context ctx;
637 };
638 
639 bool
clc_spirv_get_kernels_info(const struct clc_binary * spvbin,const struct clc_kernel_info ** out_kernels,unsigned * num_kernels,const struct clc_parsed_spec_constant ** out_spec_constants,unsigned * num_spec_constants,const struct clc_logger * logger)640 clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
641                            const struct clc_kernel_info **out_kernels,
642                            unsigned *num_kernels,
643                            const struct clc_parsed_spec_constant **out_spec_constants,
644                            unsigned *num_spec_constants,
645                            const struct clc_logger *logger)
646 {
647    struct clc_kernel_info *kernels;
648    struct clc_parsed_spec_constant *spec_constants = NULL;
649 
650    SPIRVKernelParser parser;
651 
652    if (!parser.parseBinary(*spvbin, logger))
653       return false;
654 
655    *num_kernels = parser.kernels.size();
656    *num_spec_constants = parser.specConstants.size();
657    if (!*num_kernels)
658       return false;
659 
660    kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
661                                                                sizeof(*kernels)));
662    assert(kernels);
663    for (unsigned i = 0; i < parser.kernels.size(); i++) {
664       kernels[i].name = strdup(parser.kernels[i].name.c_str());
665       kernels[i].num_args = parser.kernels[i].args.size();
666       kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
667       kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
668       if (!kernels[i].num_args)
669          continue;
670 
671       struct clc_kernel_arg *args;
672 
673       args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
674                                                        sizeof(*kernels->args)));
675       kernels[i].args = args;
676       assert(args);
677       for (unsigned j = 0; j < kernels[i].num_args; j++) {
678          if (!parser.kernels[i].args[j].name.empty())
679             args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
680          args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
681          args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
682          args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
683          args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
684       }
685    }
686 
687    if (*num_spec_constants) {
688       spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
689                                                                                   sizeof(*spec_constants)));
690       assert(spec_constants);
691 
692       for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
693          spec_constants[i] = parser.specConstants[i].second;
694       }
695    }
696 
697    *out_kernels = kernels;
698    *out_spec_constants = spec_constants;
699 
700    return true;
701 }
702 
703 void
clc_free_kernels_info(const struct clc_kernel_info * kernels,unsigned num_kernels)704 clc_free_kernels_info(const struct clc_kernel_info *kernels,
705                       unsigned num_kernels)
706 {
707    if (!kernels)
708       return;
709 
710    for (unsigned i = 0; i < num_kernels; i++) {
711       if (kernels[i].args) {
712          for (unsigned j = 0; j < kernels[i].num_args; j++) {
713             free((void *)kernels[i].args[j].name);
714             free((void *)kernels[i].args[j].type_name);
715          }
716       }
717       free((void *)kernels[i].name);
718    }
719 
720    free((void *)kernels);
721 }
722 
723 static std::pair<std::unique_ptr<::llvm::Module>, std::unique_ptr<LLVMContext>>
clc_compile_to_llvm_module(const struct clc_compile_args * args,const struct clc_logger * logger)724 clc_compile_to_llvm_module(const struct clc_compile_args *args,
725                            const struct clc_logger *logger)
726 {
727    LLVMInitializeAllTargets();
728    LLVMInitializeAllTargetInfos();
729    LLVMInitializeAllTargetMCs();
730    LLVMInitializeAllAsmPrinters();
731 
732    std::string log;
733    std::unique_ptr<LLVMContext> llvm_ctx { new LLVMContext };
734    llvm_ctx->setDiagnosticHandlerCallBack(llvm_log_handler, &log);
735 
736    std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
737    clang::DiagnosticsEngine diag { new clang::DiagnosticIDs,
738          new clang::DiagnosticOptions,
739          new clang::TextDiagnosticPrinter(*new raw_string_ostream(log),
740                                           &c->getDiagnosticOpts(), true)};
741 
742    std::vector<const char *> clang_opts = {
743       args->source.name,
744       "-triple", "spir64-unknown-unknown",
745       // By default, clang prefers to use modules to pull in the default headers,
746       // which doesn't work with our technique of embedding the headers in our binary
747       "-finclude-default-header",
748       // Add a default CL compiler version. Clang will pick the last one specified
749       // on the command line, so the app can override this one.
750       "-cl-std=cl1.2",
751       // The LLVM-SPIRV-Translator doesn't support memset with variable size
752       "-fno-builtin-memset",
753       // LLVM's optimizations can produce code that the translator can't translate
754       "-O0",
755       // Ensure inline functions are actually emitted
756       "-fgnu89-inline"
757    };
758    // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
759    // being provided by the caller here.
760    clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);
761 
762    if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
763 #if LLVM_VERSION_MAJOR >= 10
764                                                   clang_opts,
765 #else
766                                                   clang_opts.data(),
767                                                   clang_opts.data() + clang_opts.size(),
768 #endif
769                                                   diag)) {
770       clc_error(logger, "%sCouldn't create Clang invocation.\n", log.c_str());
771       return {};
772    }
773 
774    if (diag.hasErrorOccurred()) {
775       clc_error(logger, "%sErrors occurred during Clang invocation.\n",
776                 log.c_str());
777       return {};
778    }
779 
780    // This is a workaround for a Clang bug which causes the number
781    // of warnings and errors to be printed to stderr.
782    // http://www.llvm.org/bugs/show_bug.cgi?id=19735
783    c->getDiagnosticOpts().ShowCarets = false;
784 
785    c->createDiagnostics(new clang::TextDiagnosticPrinter(
786                            *new raw_string_ostream(log),
787                            &c->getDiagnosticOpts(), true));
788 
789    c->setTarget(clang::TargetInfo::CreateTargetInfo(
790                    c->getDiagnostics(), c->getInvocation().TargetOpts));
791 
792    c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;
793 
794 #ifdef USE_STATIC_OPENCL_C_H
795    c->getHeaderSearchOpts().UseBuiltinIncludes = false;
796    c->getHeaderSearchOpts().UseStandardSystemIncludes = false;
797 
798    // Add opencl-c generic search path
799    {
800       ::llvm::SmallString<128> system_header_path;
801       ::llvm::sys::path::system_temp_directory(true, system_header_path);
802       ::llvm::sys::path::append(system_header_path, "openclon12");
803       c->getHeaderSearchOpts().AddPath(system_header_path.str(),
804                                        clang::frontend::Angled,
805                                        false, false);
806 
807       ::llvm::sys::path::append(system_header_path, "opencl-c.h");
808       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
809          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());
810 
811       ::llvm::sys::path::remove_filename(system_header_path);
812       ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
813       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
814          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
815    }
816 #else
817    c->getHeaderSearchOpts().UseBuiltinIncludes = true;
818    c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
819    c->getHeaderSearchOpts().ResourceDir = CLANG_RESOURCE_DIR;
820 
821    // Add opencl-c generic search path
822    c->getHeaderSearchOpts().AddPath(CLANG_RESOURCE_DIR,
823                                     clang::frontend::Angled,
824                                     false, false);
825    // Add opencl include
826    c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
827 #endif
828 
829    if (args->num_headers) {
830       ::llvm::SmallString<128> tmp_header_path;
831       ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
832       ::llvm::sys::path::append(tmp_header_path, "openclon12");
833 
834       c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
835                                        clang::frontend::Quoted,
836                                        false, false);
837 
838       for (size_t i = 0; i < args->num_headers; i++) {
839          auto path_copy = tmp_header_path;
840          ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
841          c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
842             ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
843       }
844    }
845 
846    c->getPreprocessorOpts().addRemappedFile(
847            args->source.name,
848            ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
849 
850    // Compile the code
851    clang::EmitLLVMOnlyAction act(llvm_ctx.get());
852    if (!c->ExecuteAction(act)) {
853       clc_error(logger, "%sError executing LLVM compilation action.\n",
854                 log.c_str());
855       return {};
856    }
857 
858    return { act.takeModule(), std::move(llvm_ctx) };
859 }
860 
861 static SPIRV::VersionNumber
spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)862 spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
863 {
864    switch (version) {
865    case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
866    case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
867    case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
868    case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
869    case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
870 #ifdef HAS_SPIRV_1_4
871    case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
872 #endif
873    default:      return invalid_spirv_trans_version;
874    }
875 }
876 
877 static int
llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,std::unique_ptr<LLVMContext> context,const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)878 llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
879                   std::unique_ptr<LLVMContext> context,
880                   const struct clc_compile_args *args,
881                   const struct clc_logger *logger,
882                   struct clc_binary *out_spirv)
883 {
884    std::string log;
885 
886    SPIRV::VersionNumber version =
887       spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
888    if (version == invalid_spirv_trans_version) {
889       clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
890       return -1;
891    }
892 
893    const char *const *extensions = NULL;
894    if (args)
895       extensions = args->allowed_spirv_extensions;
896    if (!extensions) {
897       /* The SPIR-V parser doesn't handle all extensions */
898       static const char *default_extensions[] = {
899          "SPV_EXT_shader_atomic_float_add",
900          "SPV_EXT_shader_atomic_float_min_max",
901          "SPV_KHR_float_controls",
902          NULL,
903       };
904       extensions = default_extensions;
905    }
906 
907    SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
908    for (int i = 0; extensions[i]; i++) {
909 #define EXT(X) \
910       if (strcmp(#X, extensions[i]) == 0) \
911          ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
912 #include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
913 #undef EXT
914    }
915    SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
916 
917 #if LLVM_VERSION_MAJOR >= 13
918    /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
919    spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
920 #endif
921 
922    std::ostringstream spv_stream;
923    if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
924       clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
925                 log.c_str());
926       return -1;
927    }
928 
929    const std::string spv_out = spv_stream.str();
930    out_spirv->size = spv_out.size();
931    out_spirv->data = malloc(out_spirv->size);
932    memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
933 
934    return 0;
935 }
936 
937 int
clc_c_to_spir(const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spir)938 clc_c_to_spir(const struct clc_compile_args *args,
939               const struct clc_logger *logger,
940               struct clc_binary *out_spir)
941 {
942    auto pair = clc_compile_to_llvm_module(args, logger);
943    if (!pair.first)
944       return -1;
945 
946    ::llvm::SmallVector<char, 0> buffer;
947    ::llvm::BitcodeWriter writer(buffer);
948    writer.writeModule(*pair.first);
949 
950    out_spir->size = buffer.size_in_bytes();
951    out_spir->data = malloc(out_spir->size);
952    memcpy(out_spir->data, buffer.data(), out_spir->size);
953 
954    return 0;
955 }
956 
957 int
clc_c_to_spirv(const struct clc_compile_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)958 clc_c_to_spirv(const struct clc_compile_args *args,
959                const struct clc_logger *logger,
960                struct clc_binary *out_spirv)
961 {
962    auto pair = clc_compile_to_llvm_module(args, logger);
963    if (!pair.first)
964       return -1;
965    return llvm_mod_to_spirv(std::move(pair.first), std::move(pair.second), args, logger, out_spirv);
966 }
967 
968 int
clc_spir_to_spirv(const struct clc_binary * in_spir,const struct clc_logger * logger,struct clc_binary * out_spirv)969 clc_spir_to_spirv(const struct clc_binary *in_spir,
970                   const struct clc_logger *logger,
971                   struct clc_binary *out_spirv)
972 {
973    LLVMInitializeAllTargets();
974    LLVMInitializeAllTargetInfos();
975    LLVMInitializeAllTargetMCs();
976    LLVMInitializeAllAsmPrinters();
977 
978    std::unique_ptr<LLVMContext> llvm_ctx{ new LLVMContext };
979    ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
980    auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), *llvm_ctx);
981    if (!mod)
982       return -1;
983 
984    return llvm_mod_to_spirv(std::move(mod.get()), std::move(llvm_ctx), NULL, logger, out_spirv);
985 }
986 
987 class SPIRVMessageConsumer {
988 public:
SPIRVMessageConsumer(const struct clc_logger * logger)989    SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
990 
operator ()(spv_message_level_t level,const char * src,const spv_position_t & pos,const char * msg)991    void operator()(spv_message_level_t level, const char *src,
992                    const spv_position_t &pos, const char *msg)
993    {
994       switch(level) {
995       case SPV_MSG_FATAL:
996       case SPV_MSG_INTERNAL_ERROR:
997       case SPV_MSG_ERROR:
998          clc_error(logger, "(file=%s,line=%ld,column=%ld,index=%ld): %s\n",
999                    src, pos.line, pos.column, pos.index, msg);
1000          break;
1001 
1002       case SPV_MSG_WARNING:
1003          clc_warning(logger, "(file=%s,line=%ld,column=%ld,index=%ld): %s\n",
1004                      src, pos.line, pos.column, pos.index, msg);
1005          break;
1006 
1007       default:
1008          break;
1009       }
1010    }
1011 
1012 private:
1013    const struct clc_logger *logger;
1014 };
1015 
1016 int
clc_link_spirv_binaries(const struct clc_linker_args * args,const struct clc_logger * logger,struct clc_binary * out_spirv)1017 clc_link_spirv_binaries(const struct clc_linker_args *args,
1018                         const struct clc_logger *logger,
1019                         struct clc_binary *out_spirv)
1020 {
1021    std::vector<std::vector<uint32_t>> binaries;
1022 
1023    for (unsigned i = 0; i < args->num_in_objs; i++) {
1024       const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1025       std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1026       binaries.push_back(bin);
1027    }
1028 
1029    SPIRVMessageConsumer msgconsumer(logger);
1030    spvtools::Context context(spirv_target);
1031    context.SetMessageConsumer(msgconsumer);
1032    spvtools::LinkerOptions options;
1033    options.SetAllowPartialLinkage(args->create_library);
1034    options.SetCreateLibrary(args->create_library);
1035    std::vector<uint32_t> linkingResult;
1036    spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1037    if (status != SPV_SUCCESS) {
1038       return -1;
1039    }
1040 
1041    out_spirv->size = linkingResult.size() * 4;
1042    out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1043    memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1044 
1045    return 0;
1046 }
1047 
1048 int
clc_spirv_specialize(const struct clc_binary * in_spirv,const struct clc_parsed_spirv * parsed_data,const struct clc_spirv_specialization_consts * consts,struct clc_binary * out_spirv)1049 clc_spirv_specialize(const struct clc_binary *in_spirv,
1050                      const struct clc_parsed_spirv *parsed_data,
1051                      const struct clc_spirv_specialization_consts *consts,
1052                      struct clc_binary *out_spirv)
1053 {
1054    std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1055    for (unsigned i = 0; i < consts->num_specializations; ++i) {
1056       unsigned id = consts->specializations[i].id;
1057       auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1058          parsed_data->spec_constants + parsed_data->num_spec_constants,
1059          [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1060       assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1061 
1062       std::vector<uint32_t> words;
1063       switch (parsed_spec_const->type) {
1064       case CLC_SPEC_CONSTANT_BOOL:
1065          words.push_back(consts->specializations[i].value.b);
1066          break;
1067       case CLC_SPEC_CONSTANT_INT32:
1068       case CLC_SPEC_CONSTANT_UINT32:
1069       case CLC_SPEC_CONSTANT_FLOAT:
1070          words.push_back(consts->specializations[i].value.u32);
1071          break;
1072       case CLC_SPEC_CONSTANT_INT16:
1073          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1074          break;
1075       case CLC_SPEC_CONSTANT_INT8:
1076          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1077          break;
1078       case CLC_SPEC_CONSTANT_UINT16:
1079          words.push_back((uint32_t)consts->specializations[i].value.u16);
1080          break;
1081       case CLC_SPEC_CONSTANT_UINT8:
1082          words.push_back((uint32_t)consts->specializations[i].value.u8);
1083          break;
1084       case CLC_SPEC_CONSTANT_DOUBLE:
1085       case CLC_SPEC_CONSTANT_INT64:
1086       case CLC_SPEC_CONSTANT_UINT64:
1087          words.resize(2);
1088          memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1089          break;
1090       case CLC_SPEC_CONSTANT_UNKNOWN:
1091          assert(0);
1092          break;
1093       }
1094 
1095       ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1096       assert(ret.second);
1097    }
1098 
1099    spvtools::Optimizer opt(spirv_target);
1100    opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1101 
1102    std::vector<uint32_t> result;
1103    if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1104       return false;
1105 
1106    out_spirv->size = result.size() * 4;
1107    out_spirv->data = malloc(out_spirv->size);
1108    memcpy(out_spirv->data, result.data(), out_spirv->size);
1109    return true;
1110 }
1111 
1112 void
clc_dump_spirv(const struct clc_binary * spvbin,FILE * f)1113 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1114 {
1115    spvtools::SpirvTools tools(spirv_target);
1116    const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1117    std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1118    std::string out;
1119    tools.Disassemble(bin, &out,
1120                      SPV_BINARY_TO_TEXT_OPTION_INDENT |
1121                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1122    fwrite(out.c_str(), out.size(), 1, f);
1123 }
1124 
1125 void
clc_free_spir_binary(struct clc_binary * spir)1126 clc_free_spir_binary(struct clc_binary *spir)
1127 {
1128    free(spir->data);
1129 }
1130 
1131 void
clc_free_spirv_binary(struct clc_binary * spvbin)1132 clc_free_spirv_binary(struct clc_binary *spvbin)
1133 {
1134    free(spvbin->data);
1135 }
1136