1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Utilities to map clspv interface variables to OpenCL and Vulkan mappings.
7 //
8
#include "libANGLE/renderer/vulkan/clspv_utils.h"
#include "common/log_utils.h"
#include "libANGLE/renderer/vulkan/CLDeviceVk.h"

#include "libANGLE/CLDevice.h"
#include "libANGLE/renderer/driver_utils.h"

#include <cstring>
#include <mutex>
#include <string>

#include "CL/cl_half.h"

#include "clspv/Compiler.h"

#include "spirv-tools/libspirv.h"
#include "spirv-tools/libspirv.hpp"
25
26 namespace rx
27 {
// Character classes used to walk an OpenCL C printf format specifier of the form
// %[flags][width][.precision][vN][length modifier]<conversion>.

// Conversion characters that terminate a format specifier.
constexpr std::string_view kPrintfConversionSpecifiers = "diouxXfFeEgGaAcsp";
// Flag characters that may directly follow the '%'.
constexpr std::string_view kPrintfFlagsSpecifiers = "-+ #0";
// Characters that make up the field width and precision. '0' is included so that widths and
// precisions such as "10" or ".0" are skipped as a whole.
constexpr std::string_view kPrintfPrecisionSpecifiers = "0123456789.";
// Digits that make up the vector length following 'v' (2, 3, 4, 8 or 16). '1' is included so
// that "16" is consumed as a whole.
constexpr std::string_view kPrintfVectorSizeSpecifiers = "123468";
32
33 namespace
34 {
35
// Reads a value of type T from the bytes starting at |data|.
// The bytes are copied out with memcpy rather than dereferenced through a casted pointer, so
// |data| does not have to be suitably aligned for T.
template <typename T>
T ReadPtrAs(const unsigned char *data)
{
    T value;
    std::memcpy(&value, data, sizeof(T));
    return value;
}
41
// Reads a value of type T from the bytes starting at |data| and advances |data| past it.
// The bytes are copied out with memcpy rather than dereferenced through a casted pointer, so
// |data| does not have to be suitably aligned for T.
template <typename T>
T ReadPtrAsAndIncrement(unsigned char *&data)
{
    T out;
    std::memcpy(&out, data, sizeof(T));
    data += sizeof(T);
    return out;
}
49
getPrintfConversionSpecifier(std::string_view formatString)50 char getPrintfConversionSpecifier(std::string_view formatString)
51 {
52 return formatString.at(formatString.find_first_of(kPrintfConversionSpecifiers));
53 }
54
IsVectorFormat(std::string_view formatString)55 bool IsVectorFormat(std::string_view formatString)
56 {
57 ASSERT(formatString.at(0) == '%');
58
59 // go past the flags, field width and precision
60 size_t pos = formatString.find_first_not_of(kPrintfFlagsSpecifiers, 1ul);
61 pos = formatString.find_first_not_of(kPrintfPrecisionSpecifiers, pos);
62
63 return (formatString.at(pos) == 'v');
64 }
65
66 // Printing an individual formatted string into a std::string
67 // snprintf is used for parsing as OpenCL C printf is similar to printf
PrintFormattedString(const std::string & formatString,const unsigned char * data,size_t size)68 std::string PrintFormattedString(const std::string &formatString,
69 const unsigned char *data,
70 size_t size)
71 {
72 ASSERT(std::count(formatString.begin(), formatString.end(), '%') == 1);
73
74 size_t outSize = 1024;
75 std::vector<char> out(outSize);
76 out[0] = '\0';
77
78 char conversion = std::tolower(getPrintfConversionSpecifier(formatString));
79 bool finished = false;
80 while (!finished)
81 {
82 int bytesWritten = 0;
83 switch (conversion)
84 {
85 case 's':
86 {
87 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(), data);
88 break;
89 }
90 case 'f':
91 case 'e':
92 case 'g':
93 case 'a':
94 {
95 // all floats with same convention as snprintf
96 if (size == 2)
97 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
98 cl_half_to_float(ReadPtrAs<cl_half>(data)));
99 else if (size == 4)
100 bytesWritten =
101 snprintf(out.data(), outSize, formatString.c_str(), ReadPtrAs<float>(data));
102 else
103 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
104 ReadPtrAs<double>(data));
105 break;
106 }
107 default:
108 {
109 if (size == 1)
110 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
111 ReadPtrAs<uint8_t>(data));
112 else if (size == 2)
113 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
114 ReadPtrAs<uint16_t>(data));
115 else if (size == 4)
116 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
117 ReadPtrAs<uint32_t>(data));
118 else
119 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
120 ReadPtrAs<uint64_t>(data));
121 break;
122 }
123 }
124 if (bytesWritten < 0)
125 {
126 out[0] = '\0';
127 finished = true;
128 }
129 else if (bytesWritten < static_cast<long>(outSize))
130 {
131 finished = true;
132 }
133 else
134 {
135 // insufficient size redo above post increment of size
136 outSize *= 2;
137 out.resize(outSize);
138 }
139 }
140
141 return std::string(out.data());
142 }
143
144 // Spec mention vn modifier to be printed in the form v1,v2...vn
PrintVectorFormatIntoString(std::string formatString,const unsigned char * data,const uint32_t size)145 std::string PrintVectorFormatIntoString(std::string formatString,
146 const unsigned char *data,
147 const uint32_t size)
148 {
149 ASSERT(IsVectorFormat(formatString));
150
151 size_t conversionPos = formatString.find_first_of(kPrintfConversionSpecifiers);
152 // keep everything after conversion specifier in remainingFormat
153 std::string remainingFormat = formatString.substr(conversionPos + 1);
154 formatString = formatString.substr(0, conversionPos + 1);
155
156 size_t vectorPos = formatString.find_first_of('v');
157 size_t vectorLengthPos = ++vectorPos;
158 size_t vectorLengthPosEnd =
159 formatString.find_first_not_of(kPrintfVectorSizeSpecifiers, vectorLengthPos);
160
161 std::string preVectorString = formatString.substr(0, vectorPos - 1);
162 std::string postVectorString = formatString.substr(vectorLengthPosEnd, formatString.size());
163 std::string vectorLengthStr = formatString.substr(vectorLengthPos, vectorLengthPosEnd);
164 int vectorLength = std::atoi(vectorLengthStr.c_str());
165
166 // skip the vector specifier
167 formatString = preVectorString + postVectorString;
168
169 // Get the length modifier
170 int elementSize = 0;
171 if (postVectorString.find("hh") != std::string::npos)
172 {
173 elementSize = 1;
174 }
175 else if (postVectorString.find("hl") != std::string::npos)
176 {
177 elementSize = 4;
178 // snprintf doesn't recognize the hl modifier so strip it
179 size_t hl = formatString.find("hl");
180 formatString.erase(hl, 2);
181 }
182 else if (postVectorString.find("h") != std::string::npos)
183 {
184 elementSize = 2;
185 }
186 else if (postVectorString.find("l") != std::string::npos)
187 {
188 elementSize = 8;
189 }
190 else
191 {
192 WARN() << "Vector specifier is used without a length modifier. Guessing it from "
193 "vector length and argument sizes in PrintInfo. Kernel modification is "
194 "recommended.";
195 elementSize = size / vectorLength;
196 }
197
198 std::string out{""};
199 for (int i = 0; i < vectorLength - 1; i++)
200 {
201 out += PrintFormattedString(formatString, data, size / vectorLength) + ",";
202 data += elementSize;
203 }
204 out += PrintFormattedString(formatString, data, size / vectorLength) + remainingFormat;
205
206 return out;
207 }
208
209 // Process the printf stream by breaking them down into individual format specifier and processing
210 // them.
ProcessPrintfStatement(unsigned char * & data,const angle::HashMap<uint32_t,ClspvPrintfInfo> * descs,const unsigned char * dataEnd)211 void ProcessPrintfStatement(unsigned char *&data,
212 const angle::HashMap<uint32_t, ClspvPrintfInfo> *descs,
213 const unsigned char *dataEnd)
214 {
215 // printf storage buffer contents - | id | formatString | argSizes... |
216 uint32_t printfID = ReadPtrAsAndIncrement<uint32_t>(data);
217 const std::string &formatString = descs->at(printfID).formatSpecifier;
218
219 std::string printfOutput = "";
220
221 // formatString could be "<string literal> <% format specifiers ...> <string literal>"
222 // print the literal part if any first
223 size_t nextFormatSpecPos = formatString.find_first_of('%');
224 printfOutput += formatString.substr(0, nextFormatSpecPos);
225
226 // print each <% format specifier> + any string literal separately using snprintf
227 size_t idx = 0;
228 while (nextFormatSpecPos < formatString.size() - 1)
229 {
230 // Get the part of the format string before the next format specifier
231 size_t partStart = nextFormatSpecPos;
232 size_t partEnd = formatString.find_first_of('%', partStart + 1);
233 std::string partFormatString = formatString.substr(partStart, partEnd - partStart);
234
235 // Handle special cases
236 if (partEnd == partStart + 1)
237 {
238 printfOutput += "%";
239 nextFormatSpecPos = partEnd + 1;
240 continue;
241 }
242 else if (partEnd == std::string::npos && idx >= descs->at(printfID).argSizes.size())
243 {
244 // If there are no remaining arguments, the rest of the format
245 // should be printed verbatim
246 printfOutput += partFormatString;
247 break;
248 }
249
250 // The size of the argument that this format part will consume
251 const uint32_t &size = descs->at(printfID).argSizes[idx];
252
253 if (data + size > dataEnd)
254 {
255 data += size;
256 return;
257 }
258
259 // vector format need special care for snprintf
260 if (!IsVectorFormat(partFormatString))
261 {
262 // not a vector format can be printed through snprintf
263 // except for %s
264 if (getPrintfConversionSpecifier(partFormatString) == 's')
265 {
266 uint32_t stringID = ReadPtrAs<uint32_t>(data);
267 printfOutput +=
268 PrintFormattedString(partFormatString,
269 reinterpret_cast<const unsigned char *>(
270 descs->at(stringID).formatSpecifier.c_str()),
271 size);
272 }
273 else
274 {
275 printfOutput += PrintFormattedString(partFormatString, data, size);
276 }
277 data += size;
278 }
279 else
280 {
281 printfOutput += PrintVectorFormatIntoString(partFormatString, data, size);
282 data += size;
283 }
284
285 // Move to the next format part and prepare to handle the next arg
286 nextFormatSpecPos = partEnd;
287 idx++;
288 }
289
290 std::printf("%s", printfOutput.c_str());
291 }
292
GetSpvVersionAsClspvString(spv_target_env spvVersion)293 std::string GetSpvVersionAsClspvString(spv_target_env spvVersion)
294 {
295 switch (spvVersion)
296 {
297 default:
298 case SPV_ENV_VULKAN_1_0:
299 return "1.0";
300 case SPV_ENV_VULKAN_1_1:
301 return "1.3";
302 case SPV_ENV_VULKAN_1_1_SPIRV_1_4:
303 return "1.4";
304 case SPV_ENV_VULKAN_1_2:
305 return "1.5";
306 case SPV_ENV_VULKAN_1_3:
307 return "1.6";
308 }
309 }
310
GetNativeBuiltins(const vk::Renderer * renderer)311 std::vector<std::string> GetNativeBuiltins(const vk::Renderer *renderer)
312 {
313 if (renderer->getFeatures().usesNativeBuiltinClKernel.enabled)
314 {
315 return std::vector<std::string>({"fma", "half_exp2", "exp2"});
316 }
317
318 return {};
319 }
320 } // anonymous namespace
321
322 namespace clspv_cl
323 {
324
GetAddressingMode(uint32_t mask)325 cl::AddressingMode GetAddressingMode(uint32_t mask)
326 {
327 cl::AddressingMode addressingMode = cl::AddressingMode::Clamp;
328
329 switch (mask & clspv::kSamplerAddressMask)
330 {
331 case clspv::CLK_ADDRESS_NONE:
332 default:
333 addressingMode =
334 cl::FromCLenum<cl::AddressingMode>(static_cast<CLenum>(CL_ADDRESS_NONE));
335 break;
336 case clspv::CLK_ADDRESS_CLAMP_TO_EDGE:
337 addressingMode =
338 cl::FromCLenum<cl::AddressingMode>(static_cast<CLenum>(CL_ADDRESS_CLAMP_TO_EDGE));
339 break;
340 case clspv::CLK_ADDRESS_CLAMP:
341 addressingMode =
342 cl::FromCLenum<cl::AddressingMode>(static_cast<CLenum>(CL_ADDRESS_CLAMP));
343 break;
344 case clspv::CLK_ADDRESS_MIRRORED_REPEAT:
345 addressingMode =
346 cl::FromCLenum<cl::AddressingMode>(static_cast<CLenum>(CL_ADDRESS_MIRRORED_REPEAT));
347 break;
348 case clspv::CLK_ADDRESS_REPEAT:
349 addressingMode =
350 cl::FromCLenum<cl::AddressingMode>(static_cast<CLenum>(CL_ADDRESS_REPEAT));
351 break;
352 }
353
354 return addressingMode;
355 }
356
GetFilterMode(uint32_t mask)357 cl::FilterMode GetFilterMode(uint32_t mask)
358 {
359 cl::FilterMode filterMode = cl::FilterMode::Nearest;
360
361 switch (mask & clspv::kSamplerFilterMask)
362 {
363 case clspv::CLK_FILTER_NEAREST:
364 default:
365 filterMode = cl::FromCLenum<cl::FilterMode>(static_cast<CLenum>(CL_FILTER_NEAREST));
366 break;
367 case clspv::CLK_FILTER_LINEAR:
368 filterMode = cl::FromCLenum<cl::FilterMode>(static_cast<CLenum>(CL_FILTER_LINEAR));
369 break;
370 }
371
372 return filterMode;
373 }
374
375 } // namespace clspv_cl
376
377 // Process the data recorded into printf storage buffer along with the info in printfino descriptor
378 // and write it to stdout.
ClspvProcessPrintfBuffer(unsigned char * buffer,const size_t bufferSize,const angle::HashMap<uint32_t,ClspvPrintfInfo> * infoMap)379 angle::Result ClspvProcessPrintfBuffer(unsigned char *buffer,
380 const size_t bufferSize,
381 const angle::HashMap<uint32_t, ClspvPrintfInfo> *infoMap)
382 {
383 // printf storage buffer contains a series of uint32_t values
384 // the first integer is offset from second to next available free memory -- this is the amount
385 // of data written by kernel.
386 const size_t bytesWritten = ReadPtrAsAndIncrement<uint32_t>(buffer) * sizeof(uint32_t);
387 const size_t dataSize = bufferSize - sizeof(uint32_t);
388 const size_t limit = std::min(bytesWritten, dataSize);
389
390 const unsigned char *dataEnd = buffer + limit;
391 while (buffer < dataEnd)
392 {
393 ProcessPrintfStatement(buffer, infoMap, dataEnd);
394 }
395
396 if (bufferSize < bytesWritten)
397 {
398 WARN() << "Printf storage buffer was not sufficient for all printfs. Around "
399 << 100.0 * (float)(bytesWritten - bufferSize) / bytesWritten
400 << "% of them have been skipped.";
401 }
402
403 return angle::Result::Continue;
404 }
405
ClspvGetCompilerOptions(const CLDeviceVk * device)406 std::string ClspvGetCompilerOptions(const CLDeviceVk *device)
407 {
408 ASSERT(device && device->getRenderer());
409 const vk::Renderer *rendererVk = device->getRenderer();
410 std::string options{""};
411 std::vector<std::string> featureMacros;
412
413 cl_uint addressBits;
414 if (IsError(device->getInfoUInt(cl::DeviceInfo::AddressBits, &addressBits)))
415 {
416 // This shouldn't fail here
417 ASSERT(false);
418 }
419 options += addressBits == 64 ? " -arch=spir64" : " -arch=spir";
420
421 // select SPIR-V version target
422 options += " --spv-version=" + GetSpvVersionAsClspvString(device->getSpirvVersion());
423
424 cl_uint nonUniformNDRangeSupport;
425 if (IsError(device->getInfoUInt(cl::DeviceInfo::NonUniformWorkGroupSupport,
426 &nonUniformNDRangeSupport)))
427 {
428 // This shouldn't fail here
429 ASSERT(false);
430 }
431 // This "cl-arm-non-uniform-work-group-size" flag is needed to generate region reflection
432 // instructions since clspv builtin pass is conditionally dependant on it:
433 /*
434 bool NonUniformNDRangeSupported() {
435 return ((Language() == SourceLanguage::OpenCL_CPP) ||
436 (Language() == SourceLanguage::OpenCL_C_20) ||
437 (Language() == SourceLanguage::OpenCL_C_30) ||
438 ArmNonUniformWorkGroupSize()) &&
439 !UniformWorkgroupSize();
440 }
441 ...
442 Value *Ret = GidBase;
443 if (clspv::Option::NonUniformNDRangeSupported()) {
444 auto Ptr = GetPushConstantPointer(BB, clspv::PushConstant::RegionOffset);
445 auto DimPtr = Builder.CreateInBoundsGEP(VT, Ptr, Indices);
446 auto Size = Builder.CreateLoad(IT, DimPtr);
447 ...
448 */
449 options += nonUniformNDRangeSupport == CL_TRUE ? " -cl-arm-non-uniform-work-group-size" : "";
450
451 // Other internal Clspv compiler flags that are needed/required
452 options += " --long-vector";
453 options += " --global-offset";
454 options += " --enable-printf";
455 options += " --cl-kernel-arg-info";
456
457 // check for int8 support
458 if (rendererVk->getFeatures().supportsShaderInt8.enabled)
459 {
460 options += " --int8 --rewrite-packed-structs";
461 }
462
463 // 8 bit storage buffer support
464 if (!rendererVk->getFeatures().supports8BitStorageBuffer.enabled)
465 {
466 options += " --no-8bit-storage=ssbo";
467 }
468 if (!rendererVk->getFeatures().supports8BitUniformAndStorageBuffer.enabled)
469 {
470 options += " --no-8bit-storage=ubo";
471 }
472 if (!rendererVk->getFeatures().supports8BitPushConstant.enabled)
473 {
474 options += " --no-8bit-storage=pushconstant";
475 }
476
477 // 16 bit storage options
478 if (!rendererVk->getFeatures().supports16BitStorageBuffer.enabled)
479 {
480 options += " --no-16bit-storage=ssbo";
481 }
482 if (!rendererVk->getFeatures().supports16BitUniformAndStorageBuffer.enabled)
483 {
484 options += " --no-16bit-storage=ubo";
485 }
486 if (!rendererVk->getFeatures().supports16BitPushConstant.enabled)
487 {
488 options += " --no-16bit-storage=pushconstant";
489 }
490
491 if (rendererVk->getFeatures().supportsUniformBufferStandardLayout.enabled)
492 {
493 options += " --std430-ubo-layout";
494 }
495
496 std::string nativeBuiltins{""};
497 for (const std::string &builtin : GetNativeBuiltins(rendererVk))
498 {
499 nativeBuiltins += builtin + ",";
500 }
501 options += " --use-native-builtins=" + nativeBuiltins;
502 std::vector<std::string> rteModes;
503 if (rendererVk->getFeatures().supportsRoundingModeRteFp32.enabled)
504 {
505 rteModes.push_back("32");
506 }
507 if (rendererVk->getFeatures().supportsShaderFloat16.enabled)
508 {
509 options += " --fp16";
510 if (rendererVk->getFeatures().supportsRoundingModeRteFp16.enabled)
511 {
512 rteModes.push_back("16");
513 }
514 }
515 if (rendererVk->getFeatures().supportsShaderFloat64.enabled)
516 {
517 options += " --fp64";
518 featureMacros.push_back("__opencl_c_fp64");
519 if (rendererVk->getFeatures().supportsRoundingModeRteFp64.enabled)
520 {
521 rteModes.push_back("64");
522 }
523 }
524 else
525 {
526 options += " --fp64=0";
527 }
528
529 if (device->getFrontendObject().getInfo().imageSupport)
530 {
531 featureMacros.push_back("__opencl_c_images");
532 featureMacros.push_back("__opencl_c_3d_image_writes");
533 featureMacros.push_back("__opencl_c_read_write_images");
534 }
535
536 if (rendererVk->getEnabledFeatures().features.shaderInt64)
537 {
538 featureMacros.push_back("__opencl_c_int64");
539 }
540
541 if (!rteModes.empty())
542 {
543 options += " --rounding-mode-rte=";
544 options += std::reduce(std::next(rteModes.begin()), rteModes.end(), rteModes[0],
545 [](const auto a, const auto b) { return a + "," + b; });
546 }
547 if (!featureMacros.empty())
548 {
549 options += " --enable-feature-macros=";
550 options +=
551 std::reduce(std::next(featureMacros.begin()), featureMacros.end(), featureMacros[0],
552 [](const std::string a, const std::string b) { return a + "," + b; });
553 }
554
555 return options;
556 }
557
558 // A locked wrapper for clspvCompileFromSourcesString - the underneath LLVM parser is non-rentrant.
559 // So protecting it with mutex.
ClspvCompileSource(const size_t programCount,const size_t * programSizes,const char ** programs,const char * options,char ** outputBinary,size_t * outputBinarySize,char ** outputLog)560 ClspvError ClspvCompileSource(const size_t programCount,
561 const size_t *programSizes,
562 const char **programs,
563 const char *options,
564 char **outputBinary,
565 size_t *outputBinarySize,
566 char **outputLog)
567 {
568 [[clang::no_destroy]] static angle::SimpleMutex mtx;
569
570 std::lock_guard<angle::SimpleMutex> lock(mtx);
571
572 return clspvCompileFromSourcesString(programCount, programSizes, programs, options,
573 outputBinary, outputBinarySize, outputLog);
574 }
575
ClspvGetSpirvVersion(const vk::Renderer * renderer)576 spv_target_env ClspvGetSpirvVersion(const vk::Renderer *renderer)
577 {
578 uint32_t vulkanApiVersion = renderer->getDeviceVersion();
579 if (vulkanApiVersion < VK_API_VERSION_1_1)
580 {
581 // Minimum supported Vulkan version is 1.1 by Angle
582 UNREACHABLE();
583 return SPV_ENV_MAX;
584 }
585 else if (vulkanApiVersion < VK_API_VERSION_1_2)
586 {
587 // TODO: Might be worthwhile to make Vulkan 1.3 as minimum requirement
588 // http://anglebug.com/383824579
589 if (renderer->getFeatures().supportsSPIRV14.enabled)
590 {
591 return SPV_ENV_VULKAN_1_1_SPIRV_1_4;
592 }
593 return SPV_ENV_VULKAN_1_1;
594 }
595 else if (vulkanApiVersion < VK_API_VERSION_1_3)
596 {
597 return SPV_ENV_VULKAN_1_2;
598 }
599 else
600 {
601 // return the latest supported version
602 return SPV_ENV_VULKAN_1_3;
603 }
604 }
605
ClspvValidate(vk::Renderer * rendererVk,const angle::spirv::Blob & blob)606 bool ClspvValidate(vk::Renderer *rendererVk, const angle::spirv::Blob &blob)
607 {
608 spvtools::SpirvTools spvTool(ClspvGetSpirvVersion(rendererVk));
609 spvTool.SetMessageConsumer([](spv_message_level_t level, const char *,
610 const spv_position_t &position, const char *message) {
611 switch (level)
612 {
613 case SPV_MSG_FATAL:
614 case SPV_MSG_ERROR:
615 case SPV_MSG_INTERNAL_ERROR:
616 ERR() << "SPV validation error (" << position.line << "." << position.column
617 << "): " << message;
618 break;
619 case SPV_MSG_WARNING:
620 WARN() << "SPV validation warning (" << position.line << "." << position.column
621 << "): " << message;
622 break;
623 case SPV_MSG_INFO:
624 INFO() << "SPV validation info (" << position.line << "." << position.column
625 << "): " << message;
626 break;
627 case SPV_MSG_DEBUG:
628 INFO() << "SPV validation debug (" << position.line << "." << position.column
629 << "): " << message;
630 break;
631 default:
632 UNREACHABLE();
633 break;
634 }
635 });
636
637 spvtools::ValidatorOptions options;
638 if (rendererVk->getFeatures().supportsUniformBufferStandardLayout.enabled)
639 {
640 // Allow UBO layouts that conform to std430 (SSBO) layout requirements
641 options.SetUniformBufferStandardLayout(true);
642 }
643
644 return spvTool.Validate(blob.data(), blob.size(), options);
645 }
646
647 } // namespace rx
648