• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright 2021 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// CLProgramVk.cpp: Implements the class methods for CLProgramVk.
7 
8 #include "libANGLE/renderer/vulkan/CLProgramVk.h"
9 #include "libANGLE/renderer/vulkan/CLContextVk.h"
10 #include "libANGLE/renderer/vulkan/clspv_utils.h"
11 #include "libANGLE/renderer/vulkan/vk_cache_utils.h"
12 #include "libANGLE/renderer/vulkan/vk_helpers.h"
13 
14 #include "libANGLE/CLContext.h"
15 #include "libANGLE/CLKernel.h"
16 #include "libANGLE/CLProgram.h"
17 #include "libANGLE/cl_utils.h"
18 
19 #include "common/log_utils.h"
20 #include "common/string_utils.h"
21 #include "common/system_utils.h"
22 
23 #include "clspv/Compiler.h"
24 
25 #include "spirv/unified1/NonSemanticClspvReflection.h"
26 #include "spirv/unified1/spirv.hpp"
27 
28 #include "spirv-tools/libspirv.hpp"
29 #include "spirv-tools/optimizer.hpp"
30 
31 namespace rx
32 {
33 
34 namespace
35 {
// Compile-time flag that is true exactly when ANGLE_ENABLE_ASSERTS is defined for this build.
#ifdef ANGLE_ENABLE_ASSERTS
constexpr bool kAngleDebug = true;
#else
constexpr bool kAngleDebug = false;
#endif
41 
42 // Used by SPIRV-Tools to parse reflection info
ParseReflection(CLProgramVk::SpvReflectionData & reflectionData,const spv_parsed_instruction_t & spvInstr)43 spv_result_t ParseReflection(CLProgramVk::SpvReflectionData &reflectionData,
44                              const spv_parsed_instruction_t &spvInstr)
45 {
46     // Parse spir-v opcodes
47     switch (spvInstr.opcode)
48     {
49         // --- Clspv specific parsing for below cases ---
50         case spv::OpExtInst:
51         {
52             if (spvInstr.ext_inst_type != SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION)
53             {
54                 break;
55             }
56             switch (spvInstr.words[4])
57             {
58                 case NonSemanticClspvReflectionKernel:
59                 {
60                     // Extract kernel name and args - add to kernel args map
61                     std::string functionName = reflectionData.spvStrLookup[spvInstr.words[6]];
62                     uint32_t numArgs         = reflectionData.spvIntLookup[spvInstr.words[7]];
63                     reflectionData.kernelArgsMap[functionName] = CLKernelArguments();
64                     reflectionData.kernelArgsMap[functionName].resize(numArgs);
65 
66                     // Store kernel flags and attributes
67                     reflectionData.kernelFlags[functionName] =
68                         reflectionData.spvIntLookup[spvInstr.words[8]];
69                     reflectionData.kernelAttributes[functionName] =
70                         reflectionData.spvStrLookup[spvInstr.words[9]];
71 
72                     // Save kernel name to reflection table for later use/lookup in parser routine
73                     reflectionData.kernelIDs.insert(spvInstr.words[2]);
74                     reflectionData.spvStrLookup[spvInstr.words[2]] = std::string(functionName);
75 
76                     // If we already parsed some args ahead of time, populate them now
77                     if (reflectionData.kernelArgMap.contains(functionName))
78                     {
79                         for (const auto &arg : reflectionData.kernelArgMap)
80                         {
81                             uint32_t ordinal = arg.second.ordinal;
82                             reflectionData.kernelArgsMap[functionName].at(ordinal) =
83                                 std::move(arg.second);
84                         }
85                     }
86                     break;
87                 }
88                 case NonSemanticClspvReflectionArgumentInfo:
89                 {
90                     CLKernelVk::ArgInfo kernelArgInfo;
91                     kernelArgInfo.name = reflectionData.spvStrLookup[spvInstr.words[5]];
92                     // If instruction has more than 5 instruction operands (minus instruction
93                     // name/opcode), that means we have arg qualifiers. ArgumentInfo also counts as
94                     // an operand for OpExtInst. In below example, [ %e %f %g %h ] are the arg
95                     // qualifier operands.
96                     //
97                     // %a = OpExtInst %b %c ArgumentInfo %d [ %e %f %g %h ]
98                     if (spvInstr.num_operands > 5)
99                     {
100                         kernelArgInfo.typeName = reflectionData.spvStrLookup[spvInstr.words[6]];
101                         kernelArgInfo.addressQualifier =
102                             reflectionData.spvIntLookup[spvInstr.words[7]];
103                         kernelArgInfo.accessQualifier =
104                             reflectionData.spvIntLookup[spvInstr.words[8]];
105                         kernelArgInfo.typeQualifier =
106                             reflectionData.spvIntLookup[spvInstr.words[9]];
107                     }
108                     // Store kern arg for later lookup
109                     reflectionData.kernelArgInfos[spvInstr.words[2]] = std::move(kernelArgInfo);
110                     break;
111                 }
112                 case NonSemanticClspvReflectionArgumentPodUniform:
113                 case NonSemanticClspvReflectionArgumentPointerUniform:
114                 case NonSemanticClspvReflectionArgumentPodStorageBuffer:
115                 {
116                     CLKernelArgument kernelArg;
117                     if (spvInstr.num_operands == 11)
118                     {
119                         const CLKernelVk::ArgInfo &kernelArgInfo =
120                             reflectionData.kernelArgInfos[spvInstr.words[11]];
121                         kernelArg.info.name             = kernelArgInfo.name;
122                         kernelArg.info.typeName         = kernelArgInfo.typeName;
123                         kernelArg.info.addressQualifier = kernelArgInfo.addressQualifier;
124                         kernelArg.info.accessQualifier  = kernelArgInfo.accessQualifier;
125                         kernelArg.info.typeQualifier    = kernelArgInfo.typeQualifier;
126                     }
127                     kernelArg.type    = spvInstr.words[4];
128                     kernelArg.used    = true;
129                     kernelArg.ordinal = reflectionData.spvIntLookup[spvInstr.words[6]];
130                     kernelArg.op3     = reflectionData.spvIntLookup[spvInstr.words[7]];
131                     kernelArg.op4     = reflectionData.spvIntLookup[spvInstr.words[8]];
132                     kernelArg.op5     = reflectionData.spvIntLookup[spvInstr.words[9]];
133                     kernelArg.op6     = reflectionData.spvIntLookup[spvInstr.words[10]];
134 
135                     if (reflectionData.kernelIDs.contains(spvInstr.words[5]))
136                     {
137                         CLKernelArguments &kernelArgs =
138                             reflectionData
139                                 .kernelArgsMap[reflectionData.spvStrLookup[spvInstr.words[5]]];
140                         kernelArgs.at(kernelArg.ordinal) = std::move(kernelArg);
141                     }
142                     else
143                     {
144                         // Reflection kernel not yet parsed, place in temp storage for now
145                         reflectionData
146                             .kernelArgMap[reflectionData.spvStrLookup[spvInstr.words[5]]] =
147                             std::move(kernelArg);
148                     }
149 
150                     break;
151                 }
152                 case NonSemanticClspvReflectionArgumentUniform:
153                 case NonSemanticClspvReflectionArgumentWorkgroup:
154                 case NonSemanticClspvReflectionArgumentSampler:
155                 case NonSemanticClspvReflectionArgumentStorageImage:
156                 case NonSemanticClspvReflectionArgumentSampledImage:
157                 case NonSemanticClspvReflectionArgumentStorageBuffer:
158                 case NonSemanticClspvReflectionArgumentStorageTexelBuffer:
159                 case NonSemanticClspvReflectionArgumentUniformTexelBuffer:
160                 case NonSemanticClspvReflectionArgumentPodPushConstant:
161                 case NonSemanticClspvReflectionArgumentPointerPushConstant:
162                 {
163                     CLKernelArgument kernelArg;
164                     if (spvInstr.num_operands == 9)
165                     {
166                         const CLKernelVk::ArgInfo &kernelArgInfo =
167                             reflectionData.kernelArgInfos[spvInstr.words[9]];
168                         kernelArg.info.name             = kernelArgInfo.name;
169                         kernelArg.info.typeName         = kernelArgInfo.typeName;
170                         kernelArg.info.addressQualifier = kernelArgInfo.addressQualifier;
171                         kernelArg.info.accessQualifier  = kernelArgInfo.accessQualifier;
172                         kernelArg.info.typeQualifier    = kernelArgInfo.typeQualifier;
173                     }
174 
175                     kernelArg.type    = spvInstr.words[4];
176                     kernelArg.used    = true;
177                     kernelArg.ordinal = reflectionData.spvIntLookup[spvInstr.words[6]];
178                     kernelArg.op3     = reflectionData.spvIntLookup[spvInstr.words[7]];
179                     kernelArg.op4     = reflectionData.spvIntLookup[spvInstr.words[8]];
180 
181                     if (reflectionData.kernelIDs.contains(spvInstr.words[5]))
182                     {
183                         CLKernelArguments &kernelArgs =
184                             reflectionData
185                                 .kernelArgsMap[reflectionData.spvStrLookup[spvInstr.words[5]]];
186                         kernelArgs.at(kernelArg.ordinal) = std::move(kernelArg);
187                     }
188                     else
189                     {
190                         // Reflection kernel not yet parsed, place in temp storage for now
191                         reflectionData
192                             .kernelArgMap[reflectionData.spvStrLookup[spvInstr.words[5]]] =
193                             std::move(kernelArg);
194                     }
195                     break;
196                 }
197                 case NonSemanticClspvReflectionPushConstantGlobalSize:
198                 case NonSemanticClspvReflectionPushConstantGlobalOffset:
199                 case NonSemanticClspvReflectionPushConstantRegionOffset:
200                 case NonSemanticClspvReflectionPushConstantNumWorkgroups:
201                 case NonSemanticClspvReflectionPushConstantRegionGroupOffset:
202                 case NonSemanticClspvReflectionPushConstantEnqueuedLocalSize:
203                 {
204                     uint32_t offset = reflectionData.spvIntLookup[spvInstr.words[5]];
205                     uint32_t size   = reflectionData.spvIntLookup[spvInstr.words[6]];
206                     reflectionData.pushConstants[spvInstr.words[4]] = {
207                         .stageFlags = 0, .offset = offset, .size = size};
208                     break;
209                 }
210                 case NonSemanticClspvReflectionSpecConstantWorkgroupSize:
211                 {
212                     reflectionData.specConstantIDs[SpecConstantType::WorkgroupSizeX] =
213                         reflectionData.spvIntLookup[spvInstr.words[5]];
214                     reflectionData.specConstantIDs[SpecConstantType::WorkgroupSizeY] =
215                         reflectionData.spvIntLookup[spvInstr.words[6]];
216                     reflectionData.specConstantIDs[SpecConstantType::WorkgroupSizeZ] =
217                         reflectionData.spvIntLookup[spvInstr.words[7]];
218                     reflectionData.specConstantsUsed[SpecConstantType::WorkgroupSizeX] = true;
219                     reflectionData.specConstantsUsed[SpecConstantType::WorkgroupSizeY] = true;
220                     reflectionData.specConstantsUsed[SpecConstantType::WorkgroupSizeZ] = true;
221                     break;
222                 }
223                 case NonSemanticClspvReflectionPropertyRequiredWorkgroupSize:
224                 {
225                     reflectionData.kernelCompileWorkgroupSize
226                         [reflectionData.spvStrLookup[spvInstr.words[5]]] = {
227                         reflectionData.spvIntLookup[spvInstr.words[6]],
228                         reflectionData.spvIntLookup[spvInstr.words[7]],
229                         reflectionData.spvIntLookup[spvInstr.words[8]]};
230                     break;
231                 }
232                 case NonSemanticClspvReflectionSpecConstantWorkDim:
233                 {
234                     reflectionData.specConstantIDs[SpecConstantType::WorkDimension] =
235                         reflectionData.spvIntLookup[spvInstr.words[5]];
236                     reflectionData.specConstantsUsed[SpecConstantType::WorkDimension] = true;
237                     break;
238                 }
239                 case NonSemanticClspvReflectionSpecConstantGlobalOffset:
240                     reflectionData.specConstantIDs[SpecConstantType::GlobalOffsetX] =
241                         reflectionData.spvIntLookup[spvInstr.words[5]];
242                     reflectionData.specConstantIDs[SpecConstantType::GlobalOffsetY] =
243                         reflectionData.spvIntLookup[spvInstr.words[6]];
244                     reflectionData.specConstantIDs[SpecConstantType::GlobalOffsetZ] =
245                         reflectionData.spvIntLookup[spvInstr.words[7]];
246                     reflectionData.specConstantsUsed[SpecConstantType::GlobalOffsetX] = true;
247                     reflectionData.specConstantsUsed[SpecConstantType::GlobalOffsetY] = true;
248                     reflectionData.specConstantsUsed[SpecConstantType::GlobalOffsetZ] = true;
249                     break;
250                 case NonSemanticClspvReflectionPrintfInfo:
251                 {
252                     // Info on the format string used in the builtin printf call in kernel
253                     uint32_t printfID        = reflectionData.spvIntLookup[spvInstr.words[5]];
254                     std::string formatString = reflectionData.spvStrLookup[spvInstr.words[6]];
255                     reflectionData.printfInfoMap[printfID].id              = printfID;
256                     reflectionData.printfInfoMap[printfID].formatSpecifier = formatString;
257                     for (int i = 6; i < spvInstr.num_operands; i++)
258                     {
259                         uint16_t offset = spvInstr.operands[i].offset;
260                         size_t size     = reflectionData.spvIntLookup[spvInstr.words[offset]];
261                         reflectionData.printfInfoMap[printfID].argSizes.push_back(
262                             static_cast<uint32_t>(size));
263                     }
264 
265                     break;
266                 }
267                 case NonSemanticClspvReflectionPrintfBufferStorageBuffer:
268                 {
269                     // Info about the printf storage buffer that contains the formatted content
270                     uint32_t set     = reflectionData.spvIntLookup[spvInstr.words[5]];
271                     uint32_t binding = reflectionData.spvIntLookup[spvInstr.words[6]];
272                     uint32_t size    = reflectionData.spvIntLookup[spvInstr.words[7]];
273                     reflectionData.printfBufferStorage = {set, binding, 0, size};
274                     break;
275                 }
276                 case NonSemanticClspvReflectionPrintfBufferPointerPushConstant:
277                 {
278                     ERR() << "Shouldn't be here. Support of printf builtin function is enabled "
279                              "through "
280                              "PrintfBufferStorageBuffer. Check optins passed down to clspv";
281                     UNREACHABLE();
282                     return SPV_UNSUPPORTED;
283                 }
284                 case NonSemanticClspvReflectionNormalizedSamplerMaskPushConstant:
285                 case NonSemanticClspvReflectionImageArgumentInfoChannelOrderPushConstant:
286                 case NonSemanticClspvReflectionImageArgumentInfoChannelDataTypePushConstant:
287                 {
288                     uint32_t ordinal            = reflectionData.spvIntLookup[spvInstr.words[6]];
289                     uint32_t offset             = reflectionData.spvIntLookup[spvInstr.words[7]];
290                     uint32_t size               = reflectionData.spvIntLookup[spvInstr.words[8]];
291                     VkPushConstantRange pcRange = {.stageFlags = 0, .offset = offset, .size = size};
292                     reflectionData.imagePushConstants[spvInstr.words[4]].push_back(
293                         {.pcRange = pcRange, .ordinal = ordinal});
294                     break;
295                 }
296                 case NonSemanticClspvReflectionLiteralSampler:
297                 {
298                     uint32_t descriptorSet = reflectionData.spvIntLookup[spvInstr.words[5]];
299                     ASSERT(descriptorSet < static_cast<uint32_t>(DescriptorSetIndex::EnumCount));
300                     uint32_t binding         = reflectionData.spvIntLookup[spvInstr.words[6]];
301                     uint32_t mask            = reflectionData.spvIntLookup[spvInstr.words[7]];
302                     cl_bool normalizedCoords = clspv_cl::IsNormalizedCoords(mask);
303                     cl::AddressingMode addressingMode = clspv_cl::GetAddressingMode(mask);
304                     cl::FilterMode filterMode         = clspv_cl::GetFilterMode(mask);
305                     reflectionData.literalSamplers.push_back({.descriptorSet    = descriptorSet,
306                                                               .binding          = binding,
307                                                               .normalizedCoords = normalizedCoords,
308                                                               .addressingMode   = addressingMode,
309                                                               .filterMode       = filterMode});
310                     break;
311                 }
312                 default:
313                     break;
314             }
315             break;
316         }
317         // --- Regular SPIR-V opcode parsing for below cases ---
318         case spv::OpString:
319         {
320             reflectionData.spvStrLookup[spvInstr.words[1]] =
321                 reinterpret_cast<const char *>(&spvInstr.words[2]);
322             break;
323         }
324         case spv::OpConstant:
325         {
326             reflectionData.spvIntLookup[spvInstr.words[2]] = spvInstr.words[3];
327             break;
328         }
329         default:
330             break;
331     }
332     return SPV_SUCCESS;
333 }
334 
ProcessBuildOptions(const std::vector<std::string> & optionTokens,CLProgramVk::BuildType buildType)335 std::string ProcessBuildOptions(const std::vector<std::string> &optionTokens,
336                                 CLProgramVk::BuildType buildType)
337 {
338     std::string processedOptions;
339 
340     // Need to remove/replace options that are not 1-1 mapped to clspv
341     for (const std::string &optionToken : optionTokens)
342     {
343         if (optionToken == "-create-library" && buildType == CLProgramVk::BuildType::LINK)
344         {
345             processedOptions += " --output-format=bc";
346             continue;
347         }
348         processedOptions += " " + optionToken;
349     }
350 
351     switch (buildType)
352     {
353         case CLProgramVk::BuildType::COMPILE:
354             processedOptions += " --output-format=bc";
355             break;
356         case CLProgramVk::BuildType::LINK:
357             processedOptions += " -x ir";
358             break;
359         default:
360             break;
361     }
362 
363     return processedOptions;
364 }
365 
366 }  // namespace
367 
operator ()()368 void CLAsyncBuildTask::operator()()
369 {
370     ANGLE_TRACE_EVENT0("gpu.angle", "CLProgramVk::buildInternal (async)");
371     CLProgramVk::ScopedProgramCallback spc(mNotify);
372     if (!mProgramVk->buildInternal(mDevices, mOptions, mInternalOptions, mBuildType,
373                                    mLinkProgramsList))
374     {
375         ERR() << "Async build failed for program (" << mProgramVk
376               << ")! Check the build status or build log for details.";
377     }
378 }
379 
CLProgramVk(const cl::Program & program)380 CLProgramVk::CLProgramVk(const cl::Program &program)
381     : CLProgramImpl(program),
382       mContext(&program.getContext().getImpl<CLContextVk>()),
383       mAsyncBuildEvent(std::make_shared<angle::WaitableEventDone>())
384 {}
385 
init()386 angle::Result CLProgramVk::init()
387 {
388     cl::DevicePtrs devices;
389     ANGLE_TRY(mContext->getDevices(&devices));
390 
391     // The devices associated with the program object are the devices associated with context
392     for (const cl::DevicePtr &device : devices)
393     {
394         DeviceProgramData deviceProgramData{};
395         mAssociatedDevicePrograms[device->getNative()] = std::move(deviceProgramData);
396     }
397 
398     return angle::Result::Continue;
399 }
400 
init(const size_t * lengths,const unsigned char ** binaries,cl_int * binaryStatus)401 angle::Result CLProgramVk::init(const size_t *lengths,
402                                 const unsigned char **binaries,
403                                 cl_int *binaryStatus)
404 {
405     // The devices associated with program come from device_list param from
406     // clCreateProgramWithBinary
407     for (const cl::DevicePtr &device : mProgram.getDevices())
408     {
409         const unsigned char *binaryHandle = *binaries++;
410         size_t binarySize                 = *lengths++;
411 
412         // Check for header
413         if (binarySize < sizeof(ProgramBinaryOutputHeader))
414         {
415             if (binaryStatus)
416             {
417                 *binaryStatus++ = CL_INVALID_BINARY;
418             }
419             ANGLE_CL_RETURN_ERROR(CL_INVALID_BINARY);
420         }
421         binarySize -= sizeof(ProgramBinaryOutputHeader);
422 
423         // Check for valid binary version from header
424         const ProgramBinaryOutputHeader *binaryHeader =
425             reinterpret_cast<const ProgramBinaryOutputHeader *>(binaryHandle);
426         if (binaryHeader == nullptr)
427         {
428             ERR() << "NULL binary header!";
429             if (binaryStatus)
430             {
431                 *binaryStatus++ = CL_INVALID_BINARY;
432             }
433             ANGLE_CL_RETURN_ERROR(CL_INVALID_BINARY);
434         }
435         else if (binaryHeader->headerVersion < kBinaryVersion)
436         {
437             ERR() << "Binary version not compatible with runtime!";
438             if (binaryStatus)
439             {
440                 *binaryStatus++ = CL_INVALID_BINARY;
441             }
442             ANGLE_CL_RETURN_ERROR(CL_INVALID_BINARY);
443         }
444         binaryHandle += sizeof(ProgramBinaryOutputHeader);
445 
446         // See what kind of binary we have (i.e. SPIR-V or LLVM Bitcode)
447         // https://llvm.org/docs/BitCodeFormat.html#llvm-ir-magic-number
448         // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_magic_number
449         constexpr uint32_t LLVM_BC_MAGIC = 0xDEC04342;
450         constexpr uint32_t SPIRV_MAGIC   = 0x07230203;
451         const uint32_t &firstWord        = reinterpret_cast<const uint32_t *>(binaryHandle)[0];
452         bool isBC                        = firstWord == LLVM_BC_MAGIC;
453         bool isSPV                       = firstWord == SPIRV_MAGIC;
454         if (!isBC && !isSPV)
455         {
456             ERR() << "Binary is neither SPIR-V nor LLVM Bitcode!";
457             if (binaryStatus)
458             {
459                 *binaryStatus++ = CL_INVALID_BINARY;
460             }
461             ANGLE_CL_RETURN_ERROR(CL_INVALID_BINARY);
462         }
463 
464         // Add device binary to program
465         DeviceProgramData deviceBinary;
466         deviceBinary.spirvVersion = device->getImpl<CLDeviceVk>().getSpirvVersion();
467         deviceBinary.binaryType  = binaryHeader->binaryType;
468         deviceBinary.buildStatus = binaryHeader->buildStatus;
469         switch (deviceBinary.binaryType)
470         {
471             case CL_PROGRAM_BINARY_TYPE_EXECUTABLE:
472                 deviceBinary.binary.assign(binarySize / sizeof(uint32_t), 0);
473                 std::memcpy(deviceBinary.binary.data(), binaryHandle, binarySize);
474                 break;
475             case CL_PROGRAM_BINARY_TYPE_LIBRARY:
476             case CL_PROGRAM_BINARY_TYPE_COMPILED_OBJECT:
477                 deviceBinary.IR.assign(binarySize, 0);
478                 std::memcpy(deviceBinary.IR.data(), binaryHandle, binarySize);
479                 break;
480             default:
481                 UNREACHABLE();
482                 ERR() << "Invalid binary type!";
483                 if (binaryStatus)
484                 {
485                     *binaryStatus++ = CL_INVALID_BINARY;
486                 }
487                 ANGLE_CL_RETURN_ERROR(CL_INVALID_BINARY);
488         }
489         mAssociatedDevicePrograms[device->getNative()] = std::move(deviceBinary);
490         if (binaryStatus)
491         {
492             *binaryStatus++ = CL_SUCCESS;
493         }
494     }
495 
496     return angle::Result::Continue;
497 }
498 
~CLProgramVk()499 CLProgramVk::~CLProgramVk() {}
500 
build(const cl::DevicePtrs & devices,const char * options,cl::Program * notify)501 angle::Result CLProgramVk::build(const cl::DevicePtrs &devices,
502                                  const char *options,
503                                  cl::Program *notify)
504 {
505     BuildType buildType = !mProgram.getSource().empty() ? BuildType::BUILD : BuildType::BINARY;
506     const cl::DevicePtrs &devicePtrs = !devices.empty() ? devices : mProgram.getDevices();
507 
508     setBuildStatus(devicePtrs, CL_BUILD_IN_PROGRESS);
509 
510     if (notify)
511     {
512         mAsyncBuildEvent =
513             getPlatform()->postMultiThreadWorkerTask(std::make_shared<CLAsyncBuildTask>(
514                 this, devicePtrs, std::string(options ? options : ""), "", buildType,
515                 LinkProgramsList{}, notify));
516         ASSERT(mAsyncBuildEvent != nullptr);
517     }
518     else
519     {
520         if (!buildInternal(devicePtrs, std::string(options ? options : ""), "", buildType,
521                            LinkProgramsList{}))
522         {
523             ANGLE_CL_RETURN_ERROR(CL_BUILD_PROGRAM_FAILURE);
524         }
525     }
526     return angle::Result::Continue;
527 }
528 
compile(const cl::DevicePtrs & devices,const char * options,const cl::ProgramPtrs & inputHeaders,const char ** headerIncludeNames,cl::Program * notify)529 angle::Result CLProgramVk::compile(const cl::DevicePtrs &devices,
530                                    const char *options,
531                                    const cl::ProgramPtrs &inputHeaders,
532                                    const char **headerIncludeNames,
533                                    cl::Program *notify)
534 {
535     const cl::DevicePtrs &devicePtrs = !devices.empty() ? devices : mProgram.getDevices();
536 
537     // Ensure OS temp dir is available
538     std::string internalCompileOpts;
539     Optional<std::string> tmpDir = angle::GetTempDirectory();
540     if (!tmpDir.valid())
541     {
542         ERR() << "Failed to open OS temp dir";
543         ANGLE_CL_RETURN_ERROR(CL_INVALID_OPERATION);
544     }
545     internalCompileOpts += inputHeaders.empty() ? "" : " -I" + tmpDir.value();
546 
547     // Dump input headers to OS temp directory
548     for (size_t i = 0; i < inputHeaders.size(); ++i)
549     {
550         const std::string &inputHeaderSrc =
551             inputHeaders.at(i)->getImpl<CLProgramVk>().mProgram.getSource();
552         std::string headerFilePath(angle::ConcatenatePath(tmpDir.value(), headerIncludeNames[i]));
553 
554         // Sanitize path so we can use "/" as universal path separator
555         angle::MakeForwardSlashThePathSeparator(headerFilePath);
556         size_t baseDirPos = headerFilePath.find_last_of("/");
557 
558         // Ensure parent dir(s) exists
559         if (!angle::CreateDirectories(headerFilePath.substr(0, baseDirPos)))
560         {
561             ERR() << "Failed to create output path(s) for header(s)!";
562             ANGLE_CL_RETURN_ERROR(CL_INVALID_OPERATION);
563         }
564         writeFile(headerFilePath.c_str(), inputHeaderSrc.data(), inputHeaderSrc.size());
565     }
566 
567     setBuildStatus(devicePtrs, CL_BUILD_IN_PROGRESS);
568 
569     // Perform compile
570     if (notify)
571     {
572         mAsyncBuildEvent = mProgram.getContext().getPlatform().getMultiThreadPool()->postWorkerTask(
573             std::make_shared<CLAsyncBuildTask>(
574                 this, devicePtrs, std::string(options ? options : ""), internalCompileOpts,
575                 BuildType::COMPILE, LinkProgramsList{}, notify));
576         ASSERT(mAsyncBuildEvent != nullptr);
577     }
578     else
579     {
580         mAsyncBuildEvent = std::make_shared<angle::WaitableEventDone>();
581         if (!buildInternal(devicePtrs, std::string(options ? options : ""), internalCompileOpts,
582                            BuildType::COMPILE, LinkProgramsList{}))
583         {
584             ANGLE_CL_RETURN_ERROR(CL_COMPILE_PROGRAM_FAILURE);
585         }
586     }
587 
588     return angle::Result::Continue;
589 }
590 
getInfo(cl::ProgramInfo name,size_t valueSize,void * value,size_t * valueSizeRet) const591 angle::Result CLProgramVk::getInfo(cl::ProgramInfo name,
592                                    size_t valueSize,
593                                    void *value,
594                                    size_t *valueSizeRet) const
595 {
596     cl_uint valUInt            = 0u;
597     cl_bool valBool            = CL_FALSE;
598     void *valPointer           = nullptr;
599     const void *copyValue      = nullptr;
600     size_t copySize            = 0u;
601     unsigned char **outputBins = reinterpret_cast<unsigned char **>(value);
602     std::string kernelNamesList;
603     std::vector<size_t> vBinarySizes;
604 
605     switch (name)
606     {
607         case cl::ProgramInfo::NumKernels:
608             for (const auto &deviceProgram : mAssociatedDevicePrograms)
609             {
610                 valUInt += static_cast<decltype(valUInt)>(deviceProgram.second.numKernels());
611             }
612             copyValue = &valUInt;
613             copySize  = sizeof(valUInt);
614             break;
615         case cl::ProgramInfo::BinarySizes:
616         {
617             for (const auto &deviceProgram : mAssociatedDevicePrograms)
618             {
619                 vBinarySizes.push_back(
620                     sizeof(ProgramBinaryOutputHeader) +
621                     (deviceProgram.second.binaryType == CL_PROGRAM_BINARY_TYPE_EXECUTABLE
622                          ? deviceProgram.second.binary.size() * sizeof(uint32_t)
623                          : deviceProgram.second.IR.size()));
624             }
625             valPointer = vBinarySizes.data();
626             copyValue  = valPointer;
627             copySize   = vBinarySizes.size() * sizeof(size_t);
628             break;
629         }
630         case cl::ProgramInfo::Binaries:
631             for (const auto &deviceProgram : mAssociatedDevicePrograms)
632             {
633                 const void *bin =
634                     deviceProgram.second.binaryType == CL_PROGRAM_BINARY_TYPE_EXECUTABLE
635                         ? reinterpret_cast<const void *>(deviceProgram.second.binary.data())
636                         : reinterpret_cast<const void *>(deviceProgram.second.IR.data());
637                 size_t binSize =
638                     deviceProgram.second.binaryType == CL_PROGRAM_BINARY_TYPE_EXECUTABLE
639                         ? deviceProgram.second.binary.size() * sizeof(uint32_t)
640                         : deviceProgram.second.IR.size();
641                 ProgramBinaryOutputHeader header{.headerVersion = kBinaryVersion,
642                                                  .binaryType    = deviceProgram.second.binaryType,
643                                                  .buildStatus   = deviceProgram.second.buildStatus};
644 
645                 if (outputBins != nullptr)
646                 {
647                     if (*outputBins != nullptr)
648                     {
649                         std::memcpy(*outputBins, &header, sizeof(ProgramBinaryOutputHeader));
650                         std::memcpy((*outputBins) + sizeof(ProgramBinaryOutputHeader), bin,
651                                     binSize);
652                     }
653                     outputBins++;
654                 }
655 
656                 // Spec just wants pointer size here
657                 copySize += sizeof(unsigned char *);
658             }
659             // We already copied the (headers + binaries) over - nothing else left to copy
660             copyValue = nullptr;
661             break;
662         case cl::ProgramInfo::KernelNames:
663             for (const auto &deviceProgram : mAssociatedDevicePrograms)
664             {
665                 kernelNamesList = deviceProgram.second.getKernelNames();
666             }
667             valPointer = kernelNamesList.data();
668             copyValue  = valPointer;
669             copySize   = kernelNamesList.size() + 1;
670             break;
671         case cl::ProgramInfo::ScopeGlobalCtorsPresent:
672         case cl::ProgramInfo::ScopeGlobalDtorsPresent:
673             // These are deprecated by version 3.0 and are currently not supported
674             copyValue = &valBool;
675             copySize  = sizeof(cl_bool);
676             break;
677         default:
678             UNREACHABLE();
679     }
680 
681     if ((value != nullptr) && (copyValue != nullptr))
682     {
683         std::memcpy(value, copyValue, copySize);
684     }
685 
686     if (valueSizeRet != nullptr)
687     {
688         *valueSizeRet = copySize;
689     }
690 
691     return angle::Result::Continue;
692 }
693 
getBuildInfo(const cl::Device & device,cl::ProgramBuildInfo name,size_t valueSize,void * value,size_t * valueSizeRet) const694 angle::Result CLProgramVk::getBuildInfo(const cl::Device &device,
695                                         cl::ProgramBuildInfo name,
696                                         size_t valueSize,
697                                         void *value,
698                                         size_t *valueSizeRet) const
699 {
700     cl_uint valUInt                            = 0;
701     cl_build_status valStatus                  = 0;
702     const void *copyValue                      = nullptr;
703     size_t copySize                            = 0;
704     const DeviceProgramData *deviceProgramData = getDeviceProgramData(device.getNative());
705 
706     switch (name)
707     {
708         case cl::ProgramBuildInfo::Status:
709             valStatus = deviceProgramData->buildStatus;
710             copyValue = &valStatus;
711             copySize  = sizeof(valStatus);
712             break;
713         case cl::ProgramBuildInfo::Log:
714             copyValue = deviceProgramData->buildLog.c_str();
715             copySize  = deviceProgramData->buildLog.size() + 1;
716             break;
717         case cl::ProgramBuildInfo::Options:
718             copyValue = mProgramOpts.c_str();
719             copySize  = mProgramOpts.size() + 1;
720             break;
721         case cl::ProgramBuildInfo::BinaryType:
722             valUInt   = deviceProgramData->binaryType;
723             copyValue = &valUInt;
724             copySize  = sizeof(valUInt);
725             break;
726         case cl::ProgramBuildInfo::GlobalVariableTotalSize:
727             // Returns 0 if device does not support program scope global variables.
728             valUInt   = 0;
729             copyValue = &valUInt;
730             copySize  = sizeof(valUInt);
731             break;
732         default:
733             UNREACHABLE();
734     }
735 
736     if ((value != nullptr) && (copyValue != nullptr))
737     {
738         memcpy(value, copyValue, std::min(valueSize, copySize));
739     }
740 
741     if (valueSizeRet != nullptr)
742     {
743         *valueSizeRet = copySize;
744     }
745 
746     return angle::Result::Continue;
747 }
748 
createKernel(const cl::Kernel & kernel,const char * name,CLKernelImpl::Ptr * kernelOut)749 angle::Result CLProgramVk::createKernel(const cl::Kernel &kernel,
750                                         const char *name,
751                                         CLKernelImpl::Ptr *kernelOut)
752 {
753     // Wait for the compile to finish
754     mAsyncBuildEvent->wait();
755 
756     std::scoped_lock<angle::SimpleMutex> sl(mProgramMutex);
757     const auto devProgram = getDeviceProgramData(name);
758     ASSERT(devProgram != nullptr);
759 
760     // Create kernel
761     CLKernelArguments kernelArgs = devProgram->getKernelArguments(name);
762     std::string kernelAttributes = devProgram->getKernelAttributes(name);
763     std::string kernelName       = std::string(name ? name : "");
764     CLKernelVk::Ptr kernelImpl   = CLKernelVk::Ptr(
765         new (std::nothrow) CLKernelVk(kernel, kernelName, kernelAttributes, kernelArgs));
766     if (kernelImpl == nullptr)
767     {
768         ERR() << "Could not create kernel obj!";
769         ANGLE_CL_RETURN_ERROR(CL_OUT_OF_HOST_MEMORY);
770     }
771 
772     ANGLE_TRY(kernelImpl->init());
773     *kernelOut = std::move(kernelImpl);
774 
775     return angle::Result::Continue;
776 }
777 
createKernels(cl_uint numKernels,CLKernelImpl::CreateFuncs & createFuncs,cl_uint * numKernelsRet)778 angle::Result CLProgramVk::createKernels(cl_uint numKernels,
779                                          CLKernelImpl::CreateFuncs &createFuncs,
780                                          cl_uint *numKernelsRet)
781 {
782     size_t numDevKernels = 0;
783     for (const auto &dev : mAssociatedDevicePrograms)
784     {
785         numDevKernels += dev.second.numKernels();
786     }
787     if (numKernelsRet != nullptr)
788     {
789         *numKernelsRet = static_cast<cl_uint>(numDevKernels);
790     }
791 
792     if (numKernels != 0)
793     {
794         for (const auto &dev : mAssociatedDevicePrograms)
795         {
796             for (const auto &kernArgMap : dev.second.getKernelArgsMap())
797             {
798                 createFuncs.emplace_back([this, &kernArgMap](const cl::Kernel &kern) {
799                     CLKernelImpl::Ptr implPtr = nullptr;
800                     ANGLE_CL_IMPL_TRY(this->createKernel(kern, kernArgMap.first.c_str(), &implPtr));
801                     return CLKernelImpl::Ptr(std::move(implPtr));
802                 });
803             }
804         }
805     }
806     return angle::Result::Continue;
807 }
808 
getDeviceProgramData(const _cl_device_id * device) const809 const CLProgramVk::DeviceProgramData *CLProgramVk::getDeviceProgramData(
810     const _cl_device_id *device) const
811 {
812     if (!mAssociatedDevicePrograms.contains(device))
813     {
814         WARN() << "Device (" << device << ") is not associated with program (" << this << ") !";
815         return nullptr;
816     }
817     return &mAssociatedDevicePrograms.at(device);
818 }
819 
getDeviceProgramData(const char * kernelName) const820 const CLProgramVk::DeviceProgramData *CLProgramVk::getDeviceProgramData(
821     const char *kernelName) const
822 {
823     for (const auto &deviceProgram : mAssociatedDevicePrograms)
824     {
825         if (deviceProgram.second.containsKernel(kernelName))
826         {
827             return &deviceProgram.second;
828         }
829     }
830     WARN() << "Kernel name (" << kernelName << ") is not associated with program (" << this
831            << ") !";
832     return nullptr;
833 }
834 
buildInternal(const cl::DevicePtrs & devices,std::string options,std::string internalOptions,BuildType buildType,const LinkProgramsList & LinkProgramsList)835 bool CLProgramVk::buildInternal(const cl::DevicePtrs &devices,
836                                 std::string options,
837                                 std::string internalOptions,
838                                 BuildType buildType,
839                                 const LinkProgramsList &LinkProgramsList)
840 {
841     std::scoped_lock<angle::SimpleMutex> sl(mProgramMutex);
842 
843     // Cache original options string
844     mProgramOpts = options;
845 
846     // Process options and append any other internal (required) options for clspv
847     std::vector<std::string> optionTokens;
848     angle::SplitStringAlongWhitespace(options + " " + internalOptions, &optionTokens);
849     const bool createLibrary     = std::find(optionTokens.begin(), optionTokens.end(),
850                                              "-create-library") != optionTokens.end();
851     std::string processedOptions = ProcessBuildOptions(optionTokens, buildType);
852 
853     // Build for each associated device
854     for (size_t i = 0; i < devices.size(); ++i)
855     {
856         const cl::RefPointer<cl::Device> &device = devices.at(i);
857         DeviceProgramData &deviceProgramData     = mAssociatedDevicePrograms[device->getNative()];
858         deviceProgramData.spirvVersion           = device->getImpl<CLDeviceVk>().getSpirvVersion();
859 
860         // add clspv compiler options based on device features
861         processedOptions += ClspvGetCompilerOptions(&device->getImpl<CLDeviceVk>());
862 
863         if (buildType != BuildType::BINARY)
864         {
865             // Invoke clspv
866             switch (buildType)
867             {
868                 case BuildType::BUILD:
869                 case BuildType::COMPILE:
870                 {
871                     ScopedClspvContext clspvCtx;
872                     const char *clSrc = mProgram.getSource().c_str();
873 
874                     ClspvError clspvRet = ClspvCompileSource(
875                         1, NULL, static_cast<const char **>(&clSrc), processedOptions.c_str(),
876                         &clspvCtx.mOutputBin, &clspvCtx.mOutputBinSize, &clspvCtx.mOutputBuildLog);
877                     deviceProgramData.buildLog =
878                         clspvCtx.mOutputBuildLog != nullptr ? clspvCtx.mOutputBuildLog : "";
879                     if (clspvRet != CLSPV_SUCCESS)
880                     {
881                         ERR() << "OpenCL build failed with: ClspvError(" << clspvRet << ")!";
882                         deviceProgramData.buildStatus = CL_BUILD_ERROR;
883                         return false;
884                     }
885 
886                     if (buildType == BuildType::COMPILE)
887                     {
888                         deviceProgramData.IR.assign(clspvCtx.mOutputBinSize, 0);
889                         std::memcpy(deviceProgramData.IR.data(), clspvCtx.mOutputBin,
890                                     clspvCtx.mOutputBinSize);
891                         deviceProgramData.binaryType = CL_PROGRAM_BINARY_TYPE_COMPILED_OBJECT;
892                     }
893                     else
894                     {
895                         deviceProgramData.binary.assign(clspvCtx.mOutputBinSize / sizeof(uint32_t),
896                                                         0);
897                         std::memcpy(deviceProgramData.binary.data(), clspvCtx.mOutputBin,
898                                     clspvCtx.mOutputBinSize);
899                         deviceProgramData.binaryType = CL_PROGRAM_BINARY_TYPE_EXECUTABLE;
900                     }
901                     break;
902                 }
903                 case BuildType::LINK:
904                 {
905                     ScopedClspvContext clspvCtx;
906                     std::vector<size_t> vSizes;
907                     std::vector<const char *> vBins;
908                     const LinkPrograms &linkPrograms = LinkProgramsList.at(i);
909                     for (const CLProgramVk::DeviceProgramData *linkProgramData : linkPrograms)
910                     {
911                         vSizes.push_back(linkProgramData->IR.size());
912                         vBins.push_back(linkProgramData->IR.data());
913                     }
914 
915                     ClspvError clspvRet = ClspvCompileSource(
916                         linkPrograms.size(), vSizes.data(), vBins.data(), processedOptions.c_str(),
917                         &clspvCtx.mOutputBin, &clspvCtx.mOutputBinSize, &clspvCtx.mOutputBuildLog);
918                     deviceProgramData.buildLog =
919                         clspvCtx.mOutputBuildLog != nullptr ? clspvCtx.mOutputBuildLog : "";
920                     if (clspvRet != CLSPV_SUCCESS)
921                     {
922                         ERR() << "OpenCL build failed with: ClspvError(" << clspvRet << ")!";
923                         deviceProgramData.buildStatus = CL_BUILD_ERROR;
924                         return false;
925                     }
926 
927                     if (createLibrary)
928                     {
929                         deviceProgramData.IR.assign(clspvCtx.mOutputBinSize, 0);
930                         std::memcpy(deviceProgramData.IR.data(), clspvCtx.mOutputBin,
931                                     clspvCtx.mOutputBinSize);
932                         deviceProgramData.binaryType = CL_PROGRAM_BINARY_TYPE_LIBRARY;
933                     }
934                     else
935                     {
936                         deviceProgramData.binary.assign(clspvCtx.mOutputBinSize / sizeof(uint32_t),
937                                                         0);
938                         std::memcpy(deviceProgramData.binary.data(),
939                                     reinterpret_cast<char *>(clspvCtx.mOutputBin),
940                                     clspvCtx.mOutputBinSize);
941                         deviceProgramData.binaryType = CL_PROGRAM_BINARY_TYPE_EXECUTABLE;
942                     }
943                     break;
944                 }
945                 default:
946                     UNREACHABLE();
947                     return false;
948             }
949         }
950 
951         // Extract reflection info from spv binary and populate reflection data, as well as create
952         // the shader module
953         if (deviceProgramData.binaryType == CL_PROGRAM_BINARY_TYPE_EXECUTABLE)
954         {
955             // Report SPIR-V validation failure as a build failure
956             if (!ClspvValidate(mContext->getRenderer(), deviceProgramData.binary))
957             {
958                 ERR() << "Failed to validate SPIR-V binary!";
959                 deviceProgramData.buildStatus = CL_BUILD_ERROR;
960                 return false;
961             }
962 
963             spvtools::SpirvTools spvTool(deviceProgramData.spirvVersion);
964             bool parseRet = spvTool.Parse(
965                 deviceProgramData.binary,
966                 [](const spv_endianness_t endianess, const spv_parsed_header_t &instruction) {
967                     return SPV_SUCCESS;
968                 },
969                 [&deviceProgramData](const spv_parsed_instruction_t &instruction) {
970                     return ParseReflection(deviceProgramData.reflectionData, instruction);
971                 });
972             if (!parseRet)
973             {
974                 ERR() << "Failed to parse reflection info from SPIR-V!";
975                 deviceProgramData.buildStatus = CL_BUILD_ERROR;
976                 return false;
977             }
978 
979             if (mShader)
980             {
981                 mShader.reset();
982             }
983             // Strip SPIR-V binary if Vk implementation does not support non-semantic info
984             angle::spirv::Blob spvBlob =
985                 !mContext->getFeatures().supportsShaderNonSemanticInfo.enabled
986                     ? stripReflection(&deviceProgramData)
987                     : deviceProgramData.binary;
988             ASSERT(!spvBlob.empty());
989             if (IsError(vk::InitShaderModule(mContext, &mShader, spvBlob.data(),
990                                              spvBlob.size() * sizeof(uint32_t))))
991             {
992                 ERR() << "Failed to init Vulkan Shader Module!";
993                 deviceProgramData.buildStatus = CL_BUILD_ERROR;
994                 return false;
995             }
996 
997             // Setup inital push constant range
998             uint32_t pushConstantMinOffet = UINT32_MAX, pushConstantMaxOffset = 0,
999                      pushConstantMaxSize = 0;
1000             for (const auto &pushConstant : deviceProgramData.reflectionData.pushConstants)
1001             {
1002                 pushConstantMinOffet = pushConstant.second.offset < pushConstantMinOffet
1003                                            ? pushConstant.second.offset
1004                                            : pushConstantMinOffet;
1005                 if (pushConstant.second.offset >= pushConstantMaxOffset)
1006                 {
1007                     pushConstantMaxOffset = pushConstant.second.offset;
1008                     pushConstantMaxSize   = pushConstant.second.size;
1009                 }
1010             }
1011             for (const auto &pushConstant : deviceProgramData.reflectionData.imagePushConstants)
1012             {
1013                 for (const auto imageConstant : pushConstant.second)
1014                 {
1015                     pushConstantMinOffet = imageConstant.pcRange.offset < pushConstantMinOffet
1016                                                ? imageConstant.pcRange.offset
1017                                                : pushConstantMinOffet;
1018                     if (imageConstant.pcRange.offset >= pushConstantMaxOffset)
1019                     {
1020                         pushConstantMaxOffset = imageConstant.pcRange.offset;
1021                         pushConstantMaxSize   = imageConstant.pcRange.size;
1022                     }
1023                 }
1024             }
1025             deviceProgramData.pushConstRange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
1026             deviceProgramData.pushConstRange.offset =
1027                 pushConstantMinOffet == UINT32_MAX ? 0 : pushConstantMinOffet;
1028             deviceProgramData.pushConstRange.size = pushConstantMaxOffset + pushConstantMaxSize;
1029 
1030             if (kAngleDebug)
1031             {
1032                 if (mContext->getFeatures().clDumpVkSpirv.enabled)
1033                 {
1034                     angle::spirv::Print(deviceProgramData.binary);
1035                 }
1036             }
1037         }
1038         deviceProgramData.buildStatus = CL_BUILD_SUCCESS;
1039     }
1040     return true;
1041 }
1042 
stripReflection(const DeviceProgramData * deviceProgramData)1043 angle::spirv::Blob CLProgramVk::stripReflection(const DeviceProgramData *deviceProgramData)
1044 {
1045     angle::spirv::Blob binaryStripped;
1046     spvtools::Optimizer opt(deviceProgramData->spirvVersion);
1047     opt.RegisterPass(spvtools::CreateStripReflectInfoPass());
1048     spvtools::OptimizerOptions optOptions;
1049     optOptions.set_run_validator(false);
1050     if (!opt.Run(deviceProgramData->binary.data(), deviceProgramData->binary.size(),
1051                  &binaryStripped, optOptions))
1052     {
1053         ERR() << "Could not strip reflection data from binary!";
1054     }
1055     return binaryStripped;
1056 }
1057 
setBuildStatus(const cl::DevicePtrs & devices,cl_build_status status)1058 void CLProgramVk::setBuildStatus(const cl::DevicePtrs &devices, cl_build_status status)
1059 {
1060     std::scoped_lock<angle::SimpleMutex> sl(mProgramMutex);
1061 
1062     for (const auto &device : devices)
1063     {
1064         ASSERT(mAssociatedDevicePrograms.contains(device->getNative()));
1065         DeviceProgramData &deviceProgram = mAssociatedDevicePrograms.at(device->getNative());
1066         deviceProgram.buildStatus        = status;
1067     }
1068 }
1069 
getPrintfDescriptors(const std::string & kernelName) const1070 const angle::HashMap<uint32_t, ClspvPrintfInfo> *CLProgramVk::getPrintfDescriptors(
1071     const std::string &kernelName) const
1072 {
1073     const DeviceProgramData *deviceProgram = getDeviceProgramData(kernelName.c_str());
1074     if (deviceProgram)
1075     {
1076         return &deviceProgram->reflectionData.printfInfoMap;
1077     }
1078     return nullptr;
1079 }
1080 
1081 }  // namespace rx
1082