• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
// third party headers
#include <SPIRV/GlslangToSpv.h>
#include <SPIRV/SpvTools.h>
#include <glslang/Public/ShaderLang.h>
#include <spirv-tools/optimizer.hpp>

// standard library
#include <algorithm>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <new>
#include <numeric>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <system_error>
#include <thread>
#include <unordered_map>

// internal
#include "array_view.h"
#include "default_limits.h"
#include "io/dev/FileMonitor.h"
#include "lume/Log.h"
#include "shader_type.h"
#include "spirv_cross.hpp"
#include "spirv_cross_helpers_gles.h"
#include "spirv_opt_extensions.h"
45 
46 namespace {
/** Version number handed to glslang's preprocess/parse calls in this file. */
constexpr int GLSL_VERSION { 110 };
48 
// Enumerations from Engine which should match: Format, DescriptorType, ShaderStageFlagBits
/**
 * Data / pixel format identifiers.
 * The numeric values have to stay identical to the Engine's Format enumeration, so every enumerator is
 * spelled out with its explicit value. Gaps in the numbering are intentional: those values are simply not
 * listed here.
 */
enum class Format {
    UNDEFINED = 0,

    // Packed 8 and 16 bit formats
    R4G4_UNORM_PACK8 = 1,
    R4G4B4A4_UNORM_PACK16 = 2,
    B4G4R4A4_UNORM_PACK16 = 3,
    R5G6B5_UNORM_PACK16 = 4,
    B5G6R5_UNORM_PACK16 = 5,
    R5G5B5A1_UNORM_PACK16 = 6,
    B5G5R5A1_UNORM_PACK16 = 7,
    A1R5G5B5_UNORM_PACK16 = 8,

    // One channel, 8 bits
    R8_UNORM = 9,
    R8_SNORM = 10,
    R8_USCALED = 11,
    R8_SSCALED = 12,
    R8_UINT = 13,
    R8_SINT = 14,
    R8_SRGB = 15,

    // Two channels, 8 bits each
    R8G8_UNORM = 16,
    R8G8_SNORM = 17,
    R8G8_USCALED = 18,
    R8G8_SSCALED = 19,
    R8G8_UINT = 20,
    R8G8_SINT = 21,
    R8G8_SRGB = 22,

    // Three channels, 8 bits each (RGB order)
    R8G8B8_UNORM = 23,
    R8G8B8_SNORM = 24,
    R8G8B8_USCALED = 25,
    R8G8B8_SSCALED = 26,
    R8G8B8_UINT = 27,
    R8G8B8_SINT = 28,
    R8G8B8_SRGB = 29,

    // Three channels, 8 bits each (BGR order)
    B8G8R8_UNORM = 30,
    B8G8R8_SNORM = 31,
    B8G8R8_UINT = 34,
    B8G8R8_SINT = 35,
    B8G8R8_SRGB = 36,

    // Four channels, 8 bits each (RGBA order)
    R8G8B8A8_UNORM = 37,
    R8G8B8A8_SNORM = 38,
    R8G8B8A8_USCALED = 39,
    R8G8B8A8_SSCALED = 40,
    R8G8B8A8_UINT = 41,
    R8G8B8A8_SINT = 42,
    R8G8B8A8_SRGB = 43,

    // Four channels, 8 bits each (BGRA order)
    B8G8R8A8_UNORM = 44,
    B8G8R8A8_SNORM = 45,
    B8G8R8A8_UINT = 48,
    B8G8R8A8_SINT = 49,
    B8G8R8A8_SRGB = 50,

    // Four channels packed into 32 bits (ABGR, 8 bits each)
    A8B8G8R8_UNORM_PACK32 = 51,
    A8B8G8R8_SNORM_PACK32 = 52,
    A8B8G8R8_USCALED_PACK32 = 53,
    A8B8G8R8_SSCALED_PACK32 = 54,
    A8B8G8R8_UINT_PACK32 = 55,
    A8B8G8R8_SINT_PACK32 = 56,
    A8B8G8R8_SRGB_PACK32 = 57,

    // 2-10-10-10 packed into 32 bits
    A2R10G10B10_UNORM_PACK32 = 58,
    A2R10G10B10_UINT_PACK32 = 62,
    A2R10G10B10_SINT_PACK32 = 63,
    A2B10G10R10_UNORM_PACK32 = 64,
    A2B10G10R10_SNORM_PACK32 = 65,
    A2B10G10R10_USCALED_PACK32 = 66,
    A2B10G10R10_SSCALED_PACK32 = 67,
    A2B10G10R10_UINT_PACK32 = 68,
    A2B10G10R10_SINT_PACK32 = 69,

    // One channel, 16 bits
    R16_UNORM = 70,
    R16_SNORM = 71,
    R16_USCALED = 72,
    R16_SSCALED = 73,
    R16_UINT = 74,
    R16_SINT = 75,
    R16_SFLOAT = 76,

    // Two channels, 16 bits each
    R16G16_UNORM = 77,
    R16G16_SNORM = 78,
    R16G16_USCALED = 79,
    R16G16_SSCALED = 80,
    R16G16_UINT = 81,
    R16G16_SINT = 82,
    R16G16_SFLOAT = 83,

    // Three channels, 16 bits each
    R16G16B16_UNORM = 84,
    R16G16B16_SNORM = 85,
    R16G16B16_USCALED = 86,
    R16G16B16_SSCALED = 87,
    R16G16B16_UINT = 88,
    R16G16B16_SINT = 89,
    R16G16B16_SFLOAT = 90,

    // Four channels, 16 bits each
    R16G16B16A16_UNORM = 91,
    R16G16B16A16_SNORM = 92,
    R16G16B16A16_USCALED = 93,
    R16G16B16A16_SSCALED = 94,
    R16G16B16A16_UINT = 95,
    R16G16B16A16_SINT = 96,
    R16G16B16A16_SFLOAT = 97,

    // 32 bit channels
    R32_UINT = 98,
    R32_SINT = 99,
    R32_SFLOAT = 100,
    R32G32_UINT = 101,
    R32G32_SINT = 102,
    R32G32_SFLOAT = 103,
    R32G32B32_UINT = 104,
    R32G32B32_SINT = 105,
    R32G32B32_SFLOAT = 106,
    R32G32B32A32_UINT = 107,
    R32G32B32A32_SINT = 108,
    R32G32B32A32_SFLOAT = 109,

    // Packed unsigned float formats
    B10G11R11_UFLOAT_PACK32 = 122,
    E5B9G9R9_UFLOAT_PACK32 = 123,

    // Depth and stencil formats
    D16_UNORM = 124,
    X8_D24_UNORM_PACK32 = 125,
    D32_SFLOAT = 126,
    S8_UINT = 127,
    D24_UNORM_S8_UINT = 129,

    // BC block compressed formats
    BC1_RGB_UNORM_BLOCK = 131,
    BC1_RGB_SRGB_BLOCK = 132,
    BC1_RGBA_UNORM_BLOCK = 133,
    BC1_RGBA_SRGB_BLOCK = 134,
    BC2_UNORM_BLOCK = 135,
    BC2_SRGB_BLOCK = 136,
    BC3_UNORM_BLOCK = 137,
    BC3_SRGB_BLOCK = 138,
    BC4_UNORM_BLOCK = 139,
    BC4_SNORM_BLOCK = 140,
    BC5_UNORM_BLOCK = 141,
    BC5_SNORM_BLOCK = 142,
    BC6H_UFLOAT_BLOCK = 143,
    BC6H_SFLOAT_BLOCK = 144,
    BC7_UNORM_BLOCK = 145,
    BC7_SRGB_BLOCK = 146,

    // ETC2 / EAC block compressed formats
    ETC2_R8G8B8_UNORM_BLOCK = 147,
    ETC2_R8G8B8_SRGB_BLOCK = 148,
    ETC2_R8G8B8A1_UNORM_BLOCK = 149,
    ETC2_R8G8B8A1_SRGB_BLOCK = 150,
    ETC2_R8G8B8A8_UNORM_BLOCK = 151,
    ETC2_R8G8B8A8_SRGB_BLOCK = 152,
    EAC_R11_UNORM_BLOCK = 153,
    EAC_R11_SNORM_BLOCK = 154,
    EAC_R11G11_UNORM_BLOCK = 155,
    EAC_R11G11_SNORM_BLOCK = 156,

    // ASTC block compressed formats
    ASTC_4X4_UNORM_BLOCK = 157,
    ASTC_4X4_SRGB_BLOCK = 158,
    ASTC_5X4_UNORM_BLOCK = 159,
    ASTC_5X4_SRGB_BLOCK = 160,
    ASTC_5X5_UNORM_BLOCK = 161,
    ASTC_5X5_SRGB_BLOCK = 162,
    ASTC_6X5_UNORM_BLOCK = 163,
    ASTC_6X5_SRGB_BLOCK = 164,
    ASTC_6X6_UNORM_BLOCK = 165,
    ASTC_6X6_SRGB_BLOCK = 166,
    ASTC_8X5_UNORM_BLOCK = 167,
    ASTC_8X5_SRGB_BLOCK = 168,
    ASTC_8X6_UNORM_BLOCK = 169,
    ASTC_8X6_SRGB_BLOCK = 170,
    ASTC_8X8_UNORM_BLOCK = 171,
    ASTC_8X8_SRGB_BLOCK = 172,
    ASTC_10X5_UNORM_BLOCK = 173,
    ASTC_10X5_SRGB_BLOCK = 174,
    ASTC_10X6_UNORM_BLOCK = 175,
    ASTC_10X6_SRGB_BLOCK = 176,
    ASTC_10X8_UNORM_BLOCK = 177,
    ASTC_10X8_SRGB_BLOCK = 178,
    ASTC_10X10_UNORM_BLOCK = 179,
    ASTC_10X10_SRGB_BLOCK = 180,
    ASTC_12X10_UNORM_BLOCK = 181,
    ASTC_12X10_SRGB_BLOCK = 182,
    ASTC_12X12_UNORM_BLOCK = 183,
    ASTC_12X12_SRGB_BLOCK = 184,

    // 4:2:2 / 4:2:0 and multi-plane 8 bit formats
    G8B8G8R8_422_UNORM = 1000156000,
    B8G8R8G8_422_UNORM = 1000156001,
    G8_B8_R8_3PLANE_420_UNORM = 1000156002,
    G8_B8R8_2PLANE_420_UNORM = 1000156003,
    G8_B8_R8_3PLANE_422_UNORM = 1000156004,
    G8_B8R8_2PLANE_422_UNORM = 1000156005,

    /** Largest enumeration value */
    MAX_ENUM = 0x7FFFFFFF
};
395 
/**
 * Kind of resource bound through a descriptor.
 * Values are written out explicitly because they must match the Engine's DescriptorType enumeration.
 */
enum class DescriptorType {
    // Samplers and images
    SAMPLER = 0,
    COMBINED_IMAGE_SAMPLER = 1,
    SAMPLED_IMAGE = 2,
    STORAGE_IMAGE = 3,

    // Texel buffers
    UNIFORM_TEXEL_BUFFER = 4,
    STORAGE_TEXEL_BUFFER = 5,

    // Uniform / storage buffers and their dynamic variants
    UNIFORM_BUFFER = 6,
    STORAGE_BUFFER = 7,
    UNIFORM_BUFFER_DYNAMIC = 8,
    STORAGE_BUFFER_DYNAMIC = 9,

    // Others
    INPUT_ATTACHMENT = 10,
    ACCELERATION_STRUCTURE = 1000150000,

    /** Largest enumeration value; also used in this file as the "not set" default */
    MAX_ENUM = 0x7FFFFFFF
};
424 
/** Rate at which a vertex input binding advances */
enum class VertexInputRate {
    VERTEX = 0,   /**< Per vertex */
    INSTANCE = 1, /**< Per instance */
    /** Largest enumeration value */
    MAX_ENUM = 0x7FFFFFFF
};
434 
/** Limits and sentinel values used by the pipeline layout structures */
namespace PipelineLayoutConstants {
/** Number of descriptor set slots available in a pipeline layout */
static constexpr uint32_t MAX_DESCRIPTOR_SET_COUNT = 4u;
/** Sentinel (all bits set) marking a set or binding index that has not been assigned */
static constexpr uint32_t INVALID_INDEX = ~0u;
} // namespace PipelineLayoutConstants
442 
/**
 * Dimensionality of an image resource, stored in a single byte.
 * The numbering is expected to line up with spv::Dim, which is why each value is explicit.
 */
enum class ImageDimension : uint8_t {
    /** 1D image */
    DIMENSION_1D = 0,
    /** 2D image */
    DIMENSION_2D = 1,
    /** 3D image */
    DIMENSION_3D = 2,
    /** Cube image */
    DIMENSION_CUBE = 3,
    /** Rectangle image */
    DIMENSION_RECT = 4,
    /** Buffer image */
    DIMENSION_BUFFER = 5,
    /** Subpass input */
    DIMENSION_SUBPASS = 6,
};
452 
/** Single-bit image property flags; combined values fit into one byte. */
enum ImageFlags {
    IMAGE_DEPTH = 1 << 0,
    IMAGE_ARRAY = 1 << 1,
    IMAGE_MULTISAMPLE = 1 << 2,
    IMAGE_SAMPLED = 1 << 3,
    IMAGE_LOAD_STORE = 1 << 4,
};
460 
461 static_assert(int(ImageDimension::DIMENSION_1D) == spv::Dim::Dim1D);
462 static_assert(int(ImageDimension::DIMENSION_CUBE) == spv::Dim::DimCube);
463 static_assert(int(ImageDimension::DIMENSION_SUBPASS) == spv::Dim::DimSubpassData);
464 
465 /** Descriptor set layout binding */
466 struct DescriptorSetLayoutBinding {
467     /** Binding */
468     uint32_t binding { PipelineLayoutConstants::INVALID_INDEX };
469     /** Descriptor type */
470     DescriptorType descriptorType { DescriptorType::MAX_ENUM };
471     /** Descriptor count */
472     uint32_t descriptorCount { 0u };
473     /** Stage flags */
474     ShaderStageFlags shaderStageFlags;
475     ImageDimension imageDimension { ImageDimension::DIMENSION_1D };
476     uint8_t imageFlags;
477 };
478 
479 /** Descriptor set layout */
480 struct DescriptorSetLayout {
481     /** Set */
482     uint32_t set { PipelineLayoutConstants::INVALID_INDEX };
483     /** Bindings */
484     std::vector<DescriptorSetLayoutBinding> bindings;
485 };
486 
487 /** Push constant */
488 struct PushConstant {
489     /** Shader stage flags */
490     ShaderStageFlags shaderStageFlags;
491     /** Byte size */
492     uint32_t byteSize { 0u };
493 };
494 
495 /** Pipeline layout */
496 struct PipelineLayout {
497     /** Push constant */
498     PushConstant pushConstant;
499     /** Descriptor set count */
500     uint32_t descriptorSetCount { 0u };
501     /** Descriptor sets */
502     DescriptorSetLayout descriptorSetLayouts[PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT] {};
503 };
504 
/** Reserved constant id index. NOTE(review): presumably a specialization constant id limit - confirm at use. */
constexpr uint32_t RESERVED_CONSTANT_ID_INDEX = 256u;
506 
507 /** Vertex input attribute description */
508 struct VertexInputAttributeDescription {
509     /** Location */
510     uint32_t location { ~0u };
511     /** Binding */
512     uint32_t binding { ~0u };
513     /** Format */
514     Format format { Format::UNDEFINED };
515     /** Offset */
516     uint32_t offset { 0u };
517 };
518 
519 struct VertexAttributeInfo {
520     uint32_t byteSize { 0u };
521     VertexInputAttributeDescription description;
522 };
523 
/** Plain aggregate of three unsigned 32 bit components; members are not default-initialized */
struct UVec3 {
    /** First component */
    uint32_t x;
    /** Second component */
    uint32_t y;
    /** Third component */
    uint32_t z;
};
529 
530 struct ShaderReflectionData {
531     array_view<const uint8_t> reflectionData;
532 };
533 
534 struct ShaderModuleCreateInfo {
535     ShaderStageFlags shaderStageFlags;
536     array_view<const uint8_t> spvData;
537     ShaderReflectionData reflectionData;
538 };
539 
540 class FileIncluder : public glslang::TShader::Includer {
541 public:
FileIncluder(const std::filesystem::path & shaderSourcePath,array_view<const std::filesystem::path> shaderIncludePaths)542     FileIncluder(
543         const std::filesystem::path& shaderSourcePath, array_view<const std::filesystem::path> shaderIncludePaths)
544         : shaderSourcePath_(shaderSourcePath), shaderIncludePaths_(shaderIncludePaths)
545     {}
546 
includeSystem(const char * headerName,const char * includerName,size_t inclusionDepth)547     IncludeResult* includeSystem(const char* headerName, const char* includerName, size_t inclusionDepth) override
548     {
549         return include(headerName, includerName, inclusionDepth, false);
550     }
551 
includeLocal(const char * headerName,const char * includerName,size_t inclusionDepth)552     IncludeResult* includeLocal(const char* headerName, const char* includerName, size_t inclusionDepth) override
553     {
554         return include(headerName, includerName, inclusionDepth, true);
555     }
556 
releaseInclude(IncludeResult * include)557     void releaseInclude(IncludeResult* include) override
558     {
559         delete include;
560     }
561 
Reset()562     void Reset()
563     {
564         data_.clear();
565     }
566 
567 private:
SearchExistingIncludes(const char * headerName,const char * includerName,size_t inclusionDepth,bool relative)568     IncludeResult* SearchExistingIncludes(
569         const char* headerName, const char* includerName, size_t inclusionDepth, bool relative)
570     {
571         std::filesystem::path path;
572         if (relative) {
573             path.assign(shaderSourcePath_);
574             path /= includerName;
575             path = path.parent_path();
576             path /= headerName;
577             const auto pathAsString = path.make_preferred().u8string();
578             if (auto pos = data_.find(pathAsString); pos != data_.end()) {
579                 return new (std::nothrow) IncludeResult(pathAsString, pos->second.data(), pos->second.size(), 0);
580             }
581         }
582         for (const auto& includePath : shaderIncludePaths_) {
583             path.assign(includePath);
584             path /= headerName;
585             const auto pathAsString = path.make_preferred().u8string();
586             if (auto pos = data_.find(pathAsString); pos != data_.end()) {
587                 return new (std::nothrow) IncludeResult(pathAsString, pos->second.data(), pos->second.size(), 0);
588             }
589         }
590         return nullptr;
591     }
592 
include(const char * headerName,const char * includerName,size_t inclusionDepth,bool relative)593     IncludeResult* include(const char* headerName, const char* includerName, size_t inclusionDepth, bool relative)
594     {
595         IncludeResult* result = SearchExistingIncludes(headerName, includerName, inclusionDepth, relative);
596         if (result) {
597             return result;
598         }
599 
600         std::filesystem::path path;
601         bool found = false;
602         if (relative) {
603             path.assign(shaderSourcePath_);
604             path /= includerName;
605             path = path.parent_path();
606             path /= headerName;
607             found = std::filesystem::exists(path);
608         }
609         if (!found) {
610             for (const auto& includePath : shaderIncludePaths_) {
611                 path.assign(includePath);
612                 path /= headerName;
613                 found = std::filesystem::exists(path);
614                 if (found) {
615                     break;
616                 }
617             }
618         }
619         if (found) {
620             const auto pathAsString = path.make_preferred().u8string();
621             auto& headerData = data_[pathAsString];
622             const auto length = std::filesystem::file_size(path);
623             headerData.resize(length);
624 
625             std::ifstream(path, std::ios_base::binary).read(headerData.data(), length);
626 
627             return new (std::nothrow) IncludeResult(pathAsString, headerData.data(), headerData.size(), 0);
628         }
629         return nullptr;
630     }
631 
632     const std::filesystem::path& shaderSourcePath_;
633     const array_view<const std::filesystem::path> shaderIncludePaths_;
634     std::unordered_map<std::string, std::string> data_;
635 };
636 
637 struct CompilationSettings {
638     ShaderEnv shaderEnv;
639     std::vector<std::filesystem::path> shaderIncludePaths;
640     std::optional<spvtools::Optimizer> optimizer;
641     const std::filesystem::path& shaderSourcePath;
642     const std::filesystem::path& compiledShaderDestinationPath;
643     FileIncluder& includer;
644 };
645 
/** Tag identifying a reflection blob: the characters 'r', 'f', 'l' followed by a version byte. */
constexpr uint8_t REFLECTION_TAG[] = { 'r', 'f', 'l', 1 };

/** Header of a reflection blob: the tag, a type field and one 16 bit offset per section. */
struct ReflectionHeader {
    /** Same size as REFLECTION_TAG */
    uint8_t tag[sizeof(REFLECTION_TAG)];
    /** Type field */
    uint16_t type;
    /** Offset of the push constant section */
    uint16_t offsetPushConstants;
    /** Offset of the specialization constant section */
    uint16_t offsetSpecializationConstants;
    /** Offset of the descriptor set section */
    uint16_t offsetDescriptorSets;
    /** Offset of the input section */
    uint16_t offsetInputs;
    /** Offset of the local size section */
    uint16_t offsetLocalSize;
};
656 
657 struct Inputs {
658     std::filesystem::path shaderSourcesPath;
659     std::filesystem::path compiledShaderDestinationPath;
660     std::filesystem::path sourceFile;
661     std::vector<std::filesystem::path> shaderIncludePaths;
662     bool monitorChanges = false;
663     bool optimizeSpirv = false;
664     bool stripDebugInformation = false;
665     ShaderEnv envVersion = ShaderEnv::version_vulkan_1_0;
666 };
667 
/**
 * Scope guard which calls an initializer on construction and a deinitializer on destruction.
 *
 * InitFun and DeinitFun are function types; the guard stores plain pointers to the given functions.
 * The guard is neither copyable nor movable: a copy would run the deinitializer a second time when it is
 * destroyed.
 */
template<typename InitFun, typename DeinitFun>
class Scope {
private:
    InitFun* init_;
    DeinitFun* deinit_;

public:
    /** Stores both functions and immediately calls the initializer. */
    Scope(InitFun&& initializer, DeinitFun&& deinitalizer) : init_(initializer), deinit_(deinitalizer)
    {
        init_();
    }

    Scope(const Scope&) = delete;
    Scope& operator=(const Scope&) = delete;
    Scope(Scope&&) = delete;
    Scope& operator=(Scope&&) = delete;

    /** Calls the deinitializer exactly once. */
    ~Scope()
    {
        deinit_();
    }
};
685 
ReadFileToString(std::string_view aFilename)686 std::string ReadFileToString(std::string_view aFilename)
687 {
688     std::stringstream ss;
689     std::ifstream file;
690 
691     file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
692     try {
693         file.open(std::filesystem::u8path(aFilename), std::ios::in);
694 
695         if (!file.fail()) {
696             ss << file.rdbuf();
697             return ss.str();
698         }
699     } catch (std::exception const& ex) {
700         LUME_LOG_E("Error reading file: '%s': %s", aFilename.data(), ex.what());
701         return {};
702     }
703     return {};
704 }
705 
ToSpirVVersion(glslang::EShTargetClientVersion env_version)706 glslang::EShTargetLanguageVersion ToSpirVVersion(glslang::EShTargetClientVersion env_version)
707 {
708     if (env_version == glslang::EShTargetVulkan_1_0) {
709         return glslang::EShTargetSpv_1_0;
710     } else if (env_version == glslang::EShTargetVulkan_1_1) {
711         return glslang::EShTargetSpv_1_3;
712     } else if (env_version == glslang::EShTargetVulkan_1_2) {
713         return glslang::EShTargetSpv_1_5;
714 #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
715     } else if (env_version == glslang::EShTargetVulkan_1_3) {
716         return glslang::EShTargetSpv_1_6;
717 #endif
718     } else {
719         return glslang::EShTargetSpv_1_0;
720     }
721 }
722 
ConvertShaderKind(ShaderKind kind)723 std::optional<EShLanguage> ConvertShaderKind(ShaderKind kind)
724 {
725     switch (kind) {
726         case ShaderKind::VERTEX:
727             return EShLanguage::EShLangVertex;
728         case ShaderKind::FRAGMENT:
729             return EShLanguage::EShLangFragment;
730         case ShaderKind::COMPUTE:
731             return EShLanguage::EShLangCompute;
732         default:
733             return std::nullopt;
734     }
735 }
736 
ConvertShaderEnv(ShaderEnv shaderEnv)737 std::optional<glslang::EShTargetClientVersion> ConvertShaderEnv(ShaderEnv shaderEnv)
738 {
739     switch (shaderEnv) {
740         case ShaderEnv::version_vulkan_1_0:
741             return glslang::EShTargetClientVersion::EShTargetVulkan_1_0;
742         case ShaderEnv::version_vulkan_1_1:
743             return glslang::EShTargetClientVersion::EShTargetVulkan_1_1;
744         case ShaderEnv::version_vulkan_1_2:
745             return glslang::EShTargetClientVersion::EShTargetVulkan_1_2;
746 #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
747         case ShaderEnv::version_vulkan_1_3:
748             return glslang::EShTargetClientVersion::EShTargetVulkan_1_3;
749 #endif
750         default:
751             return std::nullopt;
752     }
753 }
754 
755 // setup values which are common for preproccessing and compilation
CommonShaderInit(glslang::TShader & shader,glslang::EShTargetClientVersion version,glslang::EShTargetLanguageVersion languageVersion)756 void CommonShaderInit(glslang::TShader& shader, glslang::EShTargetClientVersion version,
757     glslang::EShTargetLanguageVersion languageVersion)
758 {
759     shader.setEntryPoint("main");
760     shader.setAutoMapBindings(false);
761     shader.setAutoMapLocations(false);
762     shader.setShiftImageBinding(0);
763     shader.setShiftSamplerBinding(0);
764     shader.setShiftTextureBinding(0);
765     shader.setShiftUboBinding(0);
766     shader.setShiftSsboBinding(0);
767     shader.setShiftUavBinding(0);
768     shader.setEnvClient(glslang::EShClient::EShClientVulkan, version);
769     shader.setEnvTarget(glslang::EShTargetLanguage::EShTargetSpv, languageVersion);
770     shader.setInvertY(false);
771     shader.setNanMinMaxClamp(false);
772 }
773 
PreProcessShader(std::string_view source,ShaderKind kind,std::string_view sourceName,const CompilationSettings & settings)774 std::string PreProcessShader(
775     std::string_view source, ShaderKind kind, std::string_view sourceName, const CompilationSettings& settings)
776 {
777     const std::optional<EShLanguage> stage = ConvertShaderKind(kind);
778     if (!stage) {
779         LUME_LOG_E("Spirv preprocessing failed '%s'", "ShaderKind not recognized");
780         return {};
781     }
782 
783     const std::optional<glslang::EShTargetClientVersion> version = ConvertShaderEnv(settings.shaderEnv);
784     if (!version) {
785         LUME_LOG_E("Spirv preprocessing failed '%s'", "ShaderEnv not recognized");
786         return {};
787     }
788 
789     const glslang::EShTargetLanguageVersion languageVersion = ToSpirVVersion(version.value());
790 
791     glslang::TShader shader(stage.value());
792     CommonShaderInit(shader, version.value(), languageVersion);
793 
794     const char* shaderStrings = source.data();
795     const int shaderLengths = static_cast<int>(source.size());
796     const char* stringNames = sourceName.data();
797     static constexpr std::string_view preamble = "#extension GL_GOOGLE_include_directive : enable\n";
798     shader.setStringsWithLengthsAndNames(&shaderStrings, &shaderLengths, &stringNames, 1);
799     shader.setPreamble(preamble.data());
800 
801     std::string output;
802     const EShMessages rules =
803         static_cast<EShMessages>(EShMsgOnlyPreprocessor | EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
804     if (shader.preprocess(&kGLSLangDefaultTResource, GLSL_VERSION, EProfile::ENoProfile, false, false, rules, &output,
805         settings.includer)) {
806         output.erase(0u, preamble.size());
807     } else {
808         LUME_LOG_E("Spirv preprocessing failed '%s':\n%s", sourceName.data(), shader.getInfoLog());
809         LUME_LOG_E("Spirv preprocessing failed '%s':\n%s", sourceName.data(), shader.getInfoDebugLog());
810     }
811     return output;
812 }
813 
CompileShaderToSpirvBinary(std::string_view source,ShaderKind kind,std::string_view sourceName,const CompilationSettings & settings)814 std::vector<uint32_t> CompileShaderToSpirvBinary(
815     std::string_view source, ShaderKind kind, std::string_view sourceName, const CompilationSettings& settings)
816 {
817     const std::optional<EShLanguage> stage = ConvertShaderKind(kind);
818     if (!stage) {
819         LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderKind not recognized");
820         return {};
821     }
822 
823     const std::optional<glslang::EShTargetClientVersion> version = ConvertShaderEnv(settings.shaderEnv);
824     if (!version) {
825         LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderEnv not recognized");
826         return {};
827     }
828 
829     const glslang::EShTargetLanguageVersion languageVersion = ToSpirVVersion(version.value());
830 
831     glslang::TShader shader(stage.value());
832     CommonShaderInit(shader, version.value(), languageVersion);
833 
834     const char* shaderStrings = source.data();
835     const int shaderLengths = static_cast<int>(source.size());
836     const char* stringNames = sourceName.data();
837     shader.setStringsWithLengthsAndNames(&shaderStrings, &shaderLengths, &stringNames, 1);
838     static constexpr std::string_view preamble = "#extension GL_GOOGLE_include_directive : enable\n";
839     shader.setPreamble(preamble.data());
840 
841     const EShMessages rules = static_cast<EShMessages>(EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
842     if (!shader.parse(&kGLSLangDefaultTResource, GLSL_VERSION, EProfile::ENoProfile, false, false, rules)) {
843         const char* infoLog = shader.getInfoLog();
844         const char* debugInfoLog = shader.getInfoDebugLog();
845         lume::GetLogger().Write(lume::ILogger::LogLevel::ERROR, __FILE__, __LINE__, infoLog);
846         lume::GetLogger().Write(lume::ILogger::LogLevel::ERROR, __FILE__, __LINE__, debugInfoLog);
847         LUME_LOG_E("Spirv binary compilation failed '%s'", sourceName.data());
848         return {};
849     }
850 
851     glslang::TProgram program;
852     program.addShader(&shader);
853     if (!program.link(EShMsgDefault) || !program.mapIO()) {
854         const char* infoLog = shader.getInfoLog();
855         const char* debugInfoLog = shader.getInfoDebugLog();
856         lume::GetLogger().Write(lume::ILogger::LogLevel::ERROR, __FILE__, __LINE__, infoLog);
857         lume::GetLogger().Write(lume::ILogger::LogLevel::ERROR, __FILE__, __LINE__, debugInfoLog);
858         LUME_LOG_E("Spirv binary compilation failed '%s'", sourceName.data());
859         return {};
860     }
861 
862     std::vector<unsigned int> spirv;
863     glslang::SpvOptions spv_options;
864     spv_options.generateDebugInfo = false;
865     spv_options.disableOptimizer = true;
866     spv_options.optimizeSize = false;
867     spv::SpvBuildLogger logger;
868     glslang::TIntermediate* intermediate = program.getIntermediate(stage.value());
869     glslang::GlslangToSpv(*intermediate, spirv, &logger, &spv_options);
870 
871     const uint32_t shadercGeneratorWord = 13; // From SPIR-V XML Registry
872     const uint32_t generatorWordIndex = 2;    // SPIR-V 2.3: Physical layout
873     assert(spirv.size() > generatorWordIndex);
874     spirv[generatorWordIndex] = (spirv[generatorWordIndex] & 0xffff) | (shadercGeneratorWord << 16u);
875     return spirv;
876 }
877 
ProcessResource(const spirv_cross::Compiler & compiler,const spirv_cross::Resource & resource,ShaderStageFlags shaderStateFlags,DescriptorType type,DescriptorSetLayout * layouts)878 void ProcessResource(const spirv_cross::Compiler& compiler, const spirv_cross::Resource& resource,
879     ShaderStageFlags shaderStateFlags, DescriptorType type, DescriptorSetLayout* layouts)
880 {
881     const uint32_t set = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
882 
883     assert(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
884     if (set >= PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT) {
885         return;
886     }
887     DescriptorSetLayout& layout = layouts[set];
888     layout.set = set;
889 
890     // Collect bindings.
891     const uint32_t bindingIndex = compiler.get_decoration(resource.id, spv::DecorationBinding);
892     auto& bindings = layout.bindings;
893     const auto pos = std::find_if(bindings.begin(), bindings.end(),
894         [bindingIndex](const DescriptorSetLayoutBinding& binding) { return binding.binding == bindingIndex; });
895     if (pos == bindings.end()) {
896         const spirv_cross::SPIRType& spirType = compiler.get_type(resource.type_id);
897 
898         DescriptorSetLayoutBinding binding;
899         binding.binding = bindingIndex;
900         binding.descriptorType = type;
901         binding.descriptorCount = spirType.array.empty() ? 1 : spirType.array[0];
902         binding.shaderStageFlags = shaderStateFlags;
903         binding.imageDimension = ImageDimension(0);
904         binding.imageFlags = 0;
905         if (spirType.basetype == spirv_cross::SPIRType::BaseType::Image ||
906             spirv_cross::SPIRType::BaseType::SampledImage) {
907             binding.imageDimension = ImageDimension(spirType.image.dim);
908             binding.imageFlags = 0;
909             if (spirType.image.depth) {
910                 binding.imageFlags |= ImageFlags::IMAGE_DEPTH;
911             }
912             if (spirType.image.arrayed) {
913                 binding.imageFlags |= ImageFlags::IMAGE_ARRAY;
914             }
915             if (spirType.image.ms) {
916                 binding.imageFlags |= ImageFlags::IMAGE_MULTISAMPLE;
917             }
918             if (spirType.image.sampled == 1) {
919                 binding.imageFlags |= ImageFlags::IMAGE_SAMPLED;
920             } else if (spirType.image.sampled == 2) { // 2: parm
921                 binding.imageFlags |= ImageFlags::IMAGE_LOAD_STORE;
922             }
923         }
924         bindings.emplace_back(binding);
925     } else {
926         pos->shaderStageFlags |= shaderStateFlags;
927     }
928 }
929 
ReflectDescriptorSets(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags shaderStateFlags,DescriptorSetLayout * layouts)930 void ReflectDescriptorSets(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
931     ShaderStageFlags shaderStateFlags, DescriptorSetLayout* layouts)
932 {
933     for (const auto& ref : resources.sampled_images) {
934         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::COMBINED_IMAGE_SAMPLER, layouts);
935     }
936 
937     for (const auto& ref : resources.separate_samplers) {
938         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLER, layouts);
939     }
940 
941     for (const auto& ref : resources.separate_images) {
942         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLED_IMAGE, layouts);
943     }
944 
945     for (const auto& ref : resources.storage_images) {
946         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_IMAGE, layouts);
947     }
948 
949     for (const auto& ref : resources.uniform_buffers) {
950         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::UNIFORM_BUFFER, layouts);
951     }
952 
953     for (const auto& ref : resources.storage_buffers) {
954         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_BUFFER, layouts);
955     }
956 
957     for (const auto& ref : resources.subpass_inputs) {
958         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::INPUT_ATTACHMENT, layouts);
959     }
960 
961     for (const auto& ref : resources.acceleration_structures) {
962         ProcessResource(compiler, ref, shaderStateFlags, DescriptorType::ACCELERATION_STRUCTURE, layouts);
963     }
964 
965     std::sort(layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT,
966         [](const DescriptorSetLayout& lhs, const DescriptorSetLayout& rhs) { return (lhs.set < rhs.set); });
967 
968     std::for_each(
969         layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT, [](DescriptorSetLayout& layout) {
970             std::sort(layout.bindings.begin(), layout.bindings.end(),
971                 [](const DescriptorSetLayoutBinding& lhs, const DescriptorSetLayoutBinding& rhs) {
972                     return (lhs.binding < rhs.binding);
973                 });
974         });
975 }
976 
ReflectPushContants(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags shaderStateFlags,PushConstant & pushConstant)977 void ReflectPushContants(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
978     ShaderStageFlags shaderStateFlags, PushConstant& pushConstant)
979 {
980     // NOTE: support for only one push constant
981     if (resources.push_constant_buffers.size() > 0) {
982         pushConstant.shaderStageFlags |= shaderStateFlags;
983 
984         const auto ranges = compiler.get_active_buffer_ranges(resources.push_constant_buffers[0].id);
985         const uint32_t byteSize = std::accumulate(
986             ranges.begin(), ranges.end(), 0u, [](uint32_t byteSize, const spirv_cross::BufferRange& range) {
987                 return byteSize + static_cast<uint32_t>(range.range);
988             });
989         pushConstant.byteSize = std::max(pushConstant.byteSize, byteSize);
990     }
991 }
992 
ReflectSpecializationConstants(const spirv_cross::Compiler & compiler,ShaderStageFlags shaderStateFlags)993 std::vector<ShaderSpecializationConstant> ReflectSpecializationConstants(
994     const spirv_cross::Compiler& compiler, ShaderStageFlags shaderStateFlags)
995 {
996     std::vector<ShaderSpecializationConstant> specializationConstants;
997     uint32_t offset = 0;
998     for (auto const& constant : compiler.get_specialization_constants()) {
999         if (constant.constant_id < RESERVED_CONSTANT_ID_INDEX) {
1000             const spirv_cross::SPIRConstant& spirvConstant = compiler.get_constant(constant.id);
1001             const auto type = compiler.get_type(spirvConstant.constant_type);
1002             ShaderSpecializationConstant::Type constantType = ShaderSpecializationConstant::Type::INVALID;
1003             if (type.basetype == spirv_cross::SPIRType::Boolean) {
1004                 constantType = ShaderSpecializationConstant::Type::BOOL;
1005             } else if (type.basetype == spirv_cross::SPIRType::UInt) {
1006                 constantType = ShaderSpecializationConstant::Type::UINT32;
1007             } else if (type.basetype == spirv_cross::SPIRType::Int) {
1008                 constantType = ShaderSpecializationConstant::Type::INT32;
1009             } else if (type.basetype == spirv_cross::SPIRType::Float) {
1010                 constantType = ShaderSpecializationConstant::Type::FLOAT;
1011             } else {
1012                 assert(false && "Unhandled specialization constant type");
1013             }
1014             const uint32_t size = spirvConstant.vector_size() * spirvConstant.columns() * sizeof(uint32_t);
1015             specializationConstants.push_back(
1016                 ShaderSpecializationConstant { shaderStateFlags, constant.constant_id, constantType, offset });
1017             offset += size;
1018         }
1019     }
1020     // sorted based on offset due to offset mapping with shader combinations
1021     // NOTE: id and name indexing
1022     std::sort(specializationConstants.begin(), specializationConstants.end(),
1023         [](const auto& lhs, const auto& rhs) { return (lhs.offset < rhs.offset); });
1024 
1025     return specializationConstants;
1026 }
1027 
ConvertToVertexInputFormat(const spirv_cross::SPIRType & type)1028 Format ConvertToVertexInputFormat(const spirv_cross::SPIRType& type)
1029 {
1030     using BaseType = spirv_cross::SPIRType::BaseType;
1031 
1032     // ivecn: a vector of signed integers
1033     if (type.basetype == BaseType::Int) {
1034         switch (type.vecsize) {
1035             case 1u:
1036                 return Format::R32_SINT;
1037             case 2u:
1038                 return Format::R32G32_SINT;
1039             case 3u:
1040                 return Format::R32G32B32_SINT;
1041             case 4u:
1042                 return Format::R32G32B32A32_SINT;
1043             default:
1044                 return Format::UNDEFINED;
1045         }
1046     }
1047 
1048     // uvecn: a vector of unsigned integers
1049     if (type.basetype == BaseType::UInt) {
1050         switch (type.vecsize) {
1051             case 1u:
1052                 return Format::R32_UINT;
1053             case 2u:
1054                 return Format::R32G32_UINT;
1055             case 3u:
1056                 return Format::R32G32B32_UINT;
1057             case 4u:
1058                 return Format::R32G32B32A32_UINT;
1059             default:
1060                 return Format::UNDEFINED;
1061         }
1062     }
1063 
1064     // halfn: a vector of half-precision floating-point numbers
1065     if (type.basetype == BaseType::Half) {
1066         switch (type.vecsize) {
1067             case 1u:
1068                 return Format::R16_SFLOAT;
1069             case 2u:
1070                 return Format::R16G16_SFLOAT;
1071             case 3u:
1072                 return Format::R16G16B16_SFLOAT;
1073             case 4u:
1074                 return Format::R16G16B16A16_SFLOAT;
1075             default:
1076                 return Format::UNDEFINED;
1077         }
1078     }
1079 
1080     // vecn: a vector of single-precision floating-point numbers
1081     if (type.basetype == BaseType::Float) {
1082         switch (type.vecsize) {
1083             case 1u:
1084                 return Format::R32_SFLOAT;
1085             case 2u:
1086                 return Format::R32G32_SFLOAT;
1087             case 3u:
1088                 return Format::R32G32B32_SFLOAT;
1089             case 4u:
1090                 return Format::R32G32B32A32_SFLOAT;
1091             default:
1092                 return Format::UNDEFINED;
1093         }
1094     }
1095 
1096     return Format::UNDEFINED;
1097 }
1098 
ReflectVertexInputs(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags)1099 std::vector<VertexInputAttributeDescription> ReflectVertexInputs(const spirv_cross::Compiler& compiler,
1100     const spirv_cross::ShaderResources& resources, ShaderStageFlags /* shaderStateFlags */)
1101 {
1102     std::vector<VertexInputAttributeDescription> vertexInputAttributes;
1103 
1104     std::vector<VertexAttributeInfo> vertexAttributeInfos;
1105     std::transform(std::begin(resources.stage_inputs), std::end(resources.stage_inputs),
1106         std::back_inserter(vertexAttributeInfos), [&compiler](const spirv_cross::Resource& attr) {
1107             const spirv_cross::SPIRType& attributeType = compiler.get_type(attr.type_id);
1108             const uint32_t location = compiler.get_decoration(attr.id, spv::DecorationLocation);
1109             return VertexAttributeInfo {
1110                 // width is in bits so convert to bytes
1111                 attributeType.vecsize * (attributeType.width / 8u), // byteSize
1112                 {
1113                     // For now, assume that every vertex attribute comes from it's own binding which equals the
1114                     location,                                  // location
1115                     location,                                  // binding
1116                     ConvertToVertexInputFormat(attributeType), // format
1117                     0u                                         // offset
1118                 },                                             // description
1119             };
1120         });
1121 
1122     // Sort input attributes by binding and location.
1123     std::sort(std::begin(vertexAttributeInfos), std::end(vertexAttributeInfos),
1124         [](const VertexAttributeInfo& aA, const VertexAttributeInfo& aB) {
1125             if (aA.description.binding < aB.description.binding) {
1126                 return true;
1127             }
1128 
1129             return aA.description.location < aB.description.location;
1130         });
1131 
1132     // Create final attributes.
1133     if (!vertexAttributeInfos.empty()) {
1134         vertexInputAttributes.reserve(vertexAttributeInfos.size());
1135         std::transform(vertexAttributeInfos.cbegin(), vertexAttributeInfos.cend(),
1136             std::back_inserter(vertexInputAttributes),
1137             [](const VertexAttributeInfo& info) { return info.description; });
1138     }
1139     return vertexInputAttributes;
1140 }
1141 
/**
 * Appends an integer value to a byte buffer, least significant byte first.
 *
 * One byte is written per byte of T, but never more than four; for wider types only the low 32 bits are stored.
 *
 * @param buffer Destination buffer, grown by the written bytes.
 * @param data Value to serialize.
 */
template<typename T>
void Push(std::vector<uint8_t>& buffer, T data)
{
    constexpr size_t maxBytes = sizeof(uint32_t);
    constexpr size_t byteCount = (sizeof(T) < maxBytes) ? sizeof(T) : maxBytes;
    for (size_t index = 0; index < byteCount; ++index) {
        buffer.push_back(static_cast<uint8_t>((data >> (8u * index)) & 0xffu));
    }
}
1156 
CreatePushConstantReflection(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags shaderStateFlags)1157 std::vector<Gles::PushConstantReflection> CreatePushConstantReflection(const spirv_cross::Compiler& compiler,
1158     const spirv_cross::ShaderResources& resources, ShaderStageFlags shaderStateFlags)
1159 {
1160     std::vector<Gles::PushConstantReflection> pushConstantReflection;
1161     Gles::PushConstantReflection base {};
1162     base.stage = shaderStateFlags;
1163 
1164     for (auto& remap : resources.push_constant_buffers) {
1165         const auto& blockType = compiler.get_type(remap.base_type_id);
1166         base.name = compiler.get_name(remap.id);
1167         (void)(blockType);
1168         assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
1169 
1170         Gles::ProcessStruct(compiler, base, remap.base_type_id, pushConstantReflection);
1171     }
1172     return pushConstantReflection;
1173 }
1174 
AppendPushConstants(std::vector<uint8_t> & reflection,uint32_t byteSize,array_view<const Gles::PushConstantReflection> pushConstantReflection)1175 uint16_t AppendPushConstants(std::vector<uint8_t>& reflection, uint32_t byteSize,
1176     array_view<const Gles::PushConstantReflection> pushConstantReflection)
1177 {
1178     const uint16_t offsetPushConstants = static_cast<uint16_t>(reflection.size());
1179     if (byteSize) {
1180         reflection.push_back(1);
1181         Push(reflection, static_cast<uint16_t>(byteSize));
1182 
1183         Push(reflection, static_cast<uint8_t>(pushConstantReflection.size()));
1184         for (const auto& refl : pushConstantReflection) {
1185             Push(reflection, refl.type);
1186             Push(reflection, static_cast<uint16_t>(refl.offset));
1187             Push(reflection, static_cast<uint16_t>(refl.size));
1188             Push(reflection, static_cast<uint16_t>(refl.arraySize));
1189             Push(reflection, static_cast<uint16_t>(refl.arrayStride));
1190             Push(reflection, static_cast<uint16_t>(refl.matrixStride));
1191             Push(reflection, static_cast<uint16_t>(refl.name.size()));
1192             reflection.insert(reflection.end(), std::begin(refl.name), std::end(refl.name));
1193         }
1194     } else {
1195         reflection.push_back(0);
1196     }
1197     return offsetPushConstants;
1198 }
1199 
AppendSpecializationConstants(std::vector<uint8_t> & reflection,array_view<const ShaderSpecializationConstant> specializationConstants)1200 uint16_t AppendSpecializationConstants(
1201     std::vector<uint8_t>& reflection, array_view<const ShaderSpecializationConstant> specializationConstants)
1202 {
1203     const uint16_t offsetSpecializationConstants = static_cast<uint16_t>(reflection.size());
1204     {
1205         const auto size = static_cast<uint32_t>(specializationConstants.size());
1206         Push(reflection, static_cast<uint32_t>(specializationConstants.size()));
1207     }
1208     for (auto const& constant : specializationConstants) {
1209         Push(reflection, static_cast<uint32_t>(constant.id));
1210         Push(reflection, static_cast<uint32_t>(constant.type));
1211     }
1212     return offsetSpecializationConstants;
1213 }
1214 
AppendDescriptorSets(std::vector<uint8_t> & reflection,const PipelineLayout & pipelineLayout)1215 uint16_t AppendDescriptorSets(std::vector<uint8_t>& reflection, const PipelineLayout& pipelineLayout)
1216 {
1217     const uint16_t offsetDescriptorSets = static_cast<uint16_t>(reflection.size());
1218     {
1219         Push(reflection, static_cast<uint16_t>(pipelineLayout.descriptorSetCount));
1220     }
1221     auto begin = std::begin(pipelineLayout.descriptorSetLayouts);
1222     auto end = begin;
1223     std::advance(end, pipelineLayout.descriptorSetCount);
1224     std::for_each(begin, end, [&reflection](const DescriptorSetLayout& layout) {
1225         Push(reflection, static_cast<uint16_t>(layout.set));
1226         Push(reflection, static_cast<uint16_t>(layout.bindings.size()));
1227         for (const auto& binding : layout.bindings) {
1228             Push(reflection, static_cast<uint16_t>(binding.binding));
1229             Push(reflection, static_cast<uint16_t>(binding.descriptorType));
1230             Push(reflection, static_cast<uint16_t>(binding.descriptorCount));
1231             Push(reflection, static_cast<uint8_t>(binding.imageDimension));
1232             Push(reflection, static_cast<uint8_t>(binding.imageFlags));
1233         }
1234     });
1235     return offsetDescriptorSets;
1236 }
1237 
AppendVertexInputAttributes(std::vector<uint8_t> & reflection,array_view<const VertexInputAttributeDescription> vertexInputAttributes)1238 uint16_t AppendVertexInputAttributes(
1239     std::vector<uint8_t>& reflection, array_view<const VertexInputAttributeDescription> vertexInputAttributes)
1240 {
1241     const uint16_t offsetInputs = static_cast<uint16_t>(reflection.size());
1242     const auto size = static_cast<uint16_t>(vertexInputAttributes.size());
1243     Push(reflection, size);
1244     for (const auto& input : vertexInputAttributes) {
1245         Push(reflection, static_cast<uint16_t>(input.location));
1246         Push(reflection, static_cast<uint16_t>(input.format));
1247     }
1248     return offsetInputs;
1249 }
1250 
AppendExecutionLocalSize(std::vector<uint8_t> & reflection,const spirv_cross::Compiler & compiler)1251 uint16_t AppendExecutionLocalSize(std::vector<uint8_t>& reflection, const spirv_cross::Compiler& compiler)
1252 {
1253     const uint16_t offsetLocalSize = static_cast<uint16_t>(reflection.size());
1254     uint32_t size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 0u); // X = 0
1255     Push(reflection, size);
1256 
1257     size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 1u); // Y = 1
1258     Push(reflection, size);
1259 
1260     size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 2u); // Z = 2
1261     Push(reflection, size);
1262     return offsetLocalSize;
1263 }
1264 
ReflectSpvBinary(const std::vector<uint32_t> & aBinary,ShaderKind kind)1265 std::vector<uint8_t> ReflectSpvBinary(const std::vector<uint32_t>& aBinary, ShaderKind kind)
1266 {
1267     const spirv_cross::Compiler compiler(aBinary);
1268 
1269     const auto shaderStateFlags = ShaderStageFlags(kind);
1270 
1271     const spirv_cross::ShaderResources resources = compiler.get_shader_resources();
1272 
1273     PipelineLayout pipelineLayout;
1274     ReflectDescriptorSets(compiler, resources, shaderStateFlags, pipelineLayout.descriptorSetLayouts);
1275     pipelineLayout.descriptorSetCount =
1276         static_cast<uint32_t>(std::count_if(std::begin(pipelineLayout.descriptorSetLayouts),
1277             std::end(pipelineLayout.descriptorSetLayouts), [](const DescriptorSetLayout& layout) {
1278                 return layout.set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT;
1279             }));
1280     ReflectPushContants(compiler, resources, shaderStateFlags, pipelineLayout.pushConstant);
1281 
1282     // some additional information mainly for GL
1283     const auto pushConstantReflection = CreatePushConstantReflection(compiler, resources, shaderStateFlags);
1284 
1285     const auto specializationConstants = ReflectSpecializationConstants(compiler, shaderStateFlags);
1286 
1287     // NOTE: this is done for all although the name is 'Vertex'InputAttributes
1288     const auto vertexInputAttributes = ReflectVertexInputs(compiler, resources, shaderStateFlags);
1289 
1290     std::vector<uint8_t> reflection;
1291     reflection.reserve(512u);
1292     // tag
1293     reflection.insert(reflection.end(), std::begin(REFLECTION_TAG), std::end(REFLECTION_TAG));
1294 
1295     // shader type
1296     Push(reflection, static_cast<uint16_t>(shaderStateFlags.flags));
1297 
1298     // offsets
1299     // allocate size for the five offsets. data will be added after the offsets.
1300     reflection.resize(reflection.size() + sizeof(uint16_t) * 5u);
1301 
1302     // push constants
1303     const uint16_t offsetPushConstants =
1304         AppendPushConstants(reflection, pipelineLayout.pushConstant.byteSize, pushConstantReflection);
1305 
1306     // specialization constants
1307     const uint16_t offsetSpecializationConstants = AppendSpecializationConstants(reflection, specializationConstants);
1308 
1309     // descriptor sets
1310     const uint16_t offsetDescriptorSets = AppendDescriptorSets(reflection, pipelineLayout);
1311 
1312     // inputs
1313     const uint16_t offsetInputs = AppendVertexInputAttributes(reflection, vertexInputAttributes);
1314 
1315     // local size
1316     const uint16_t offsetLocalSize =
1317         (shaderStateFlags & ShaderStageFlagBits::COMPUTE_BIT) ? AppendExecutionLocalSize(reflection, compiler) : 0u;
1318 
1319     // update offsets to real values
1320     {
1321         auto ptr = reflection.data() + (sizeof(REFLECTION_TAG) + sizeof(uint16_t));
1322         *ptr++ = offsetPushConstants & 0xffu;
1323         *ptr++ = (offsetPushConstants >> 8u) & 0xffu;
1324         *ptr++ = offsetSpecializationConstants & 0xffu;
1325         *ptr++ = (offsetSpecializationConstants >> 8u) & 0xffu;
1326         *ptr++ = offsetDescriptorSets & 0xffu;
1327         *ptr++ = (offsetDescriptorSets >> 8u) & 0xffu;
1328         *ptr++ = offsetInputs & 0xffu;
1329         *ptr++ = (offsetInputs >> 8u) & 0xffu;
1330         *ptr++ = offsetLocalSize & 0xffu;
1331         *ptr++ = (offsetLocalSize >> 8u) & 0xffu;
1332     }
1333 
1334     return reflection;
1335 }
1336 
/** Descriptor set index and binding index of a resource, each stored in 8 bits. */
struct Binding {
    /** Descriptor set index */
    uint8_t set;
    /** Binding index inside the set */
    uint8_t bind;
};
1341 
GetBinding(Gles::CoreCompiler & compiler,spirv_cross::ID id)1342 Binding GetBinding(Gles::CoreCompiler& compiler, spirv_cross::ID id)
1343 {
1344     const uint32_t dset = compiler.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
1345     const uint32_t dbind = compiler.get_decoration(id, spv::Decoration::DecorationBinding);
1346     assert(dset < Gles::ResourceLimits::MAX_SETS);
1347     assert(dbind < Gles::ResourceLimits::MAX_BIND_IN_SET);
1348     const uint8_t set = static_cast<uint8_t>(dset);
1349     const uint8_t bind = static_cast<uint8_t>(dbind);
1350     return { set, bind };
1351 }
1352 
SortSets(PipelineLayout & pipelineLayout)1353 void SortSets(PipelineLayout& pipelineLayout)
1354 {
1355     pipelineLayout.descriptorSetCount = 0;
1356     for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
1357         DescriptorSetLayout& currSet = pipelineLayout.descriptorSetLayouts[idx];
1358         if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
1359             pipelineLayout.descriptorSetCount++;
1360             std::sort(currSet.bindings.begin(), currSet.bindings.end(),
1361                 [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
1362         }
1363     }
1364 }
1365 
Collect(Gles::CoreCompiler & compiler,const spirv_cross::SmallVector<spirv_cross::Resource> & resources,const uint32_t forceBinding=0)1366 void Collect(Gles::CoreCompiler& compiler, const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
1367     const uint32_t forceBinding = 0)
1368 {
1369     std::string name;
1370 
1371     for (const auto& remap : resources) {
1372         const auto binding = GetBinding(compiler, remap.id);
1373 
1374         name.resize(name.capacity() - 1);
1375         const auto nameLen = snprintf(name.data(), name.size(), "s%u_b%u", binding.set, binding.bind);
1376         name.resize(nameLen);
1377 
1378         // if name is empty it's a block and we need to rename the base_type_id i.e.
1379         // "uniform <base_type_id> { vec4 foo; } <id>;"
1380         if (auto origname = compiler.get_name(remap.id); origname.empty()) {
1381             compiler.set_name(remap.base_type_id, name);
1382             name.insert(name.begin(), '_');
1383             compiler.set_name(remap.id, name);
1384         } else {
1385             // "uniform <id> vec4 foo;"
1386             compiler.set_name(remap.id, name);
1387         }
1388 
1389         compiler.unset_decoration(remap.id, spv::DecorationDescriptorSet);
1390         compiler.unset_decoration(remap.id, spv::DecorationBinding);
1391         if (forceBinding > 0) {
1392             compiler.set_decoration(remap.id, spv::DecorationBinding,
1393                 forceBinding - 1); // will be over-written later. (special handling)
1394         }
1395     }
1396 }
1397 
1398 struct ShaderModulePlatformDataGLES {
1399     std::vector<Gles::PushConstantReflection> infos;
1400 };
1401 
CollectRes(Gles::CoreCompiler & compiler,const spirv_cross::ShaderResources & res,ShaderModulePlatformDataGLES & plat)1402 void CollectRes(
1403     Gles::CoreCompiler& compiler, const spirv_cross::ShaderResources& res, ShaderModulePlatformDataGLES& plat)
1404 {
1405     // collect names for later linkage
1406     static constexpr uint32_t defaultBinding = 11;
1407     Collect(compiler, res.storage_buffers, defaultBinding + 1);
1408     Collect(compiler, res.storage_images, defaultBinding + 1);
1409     Collect(compiler, res.uniform_buffers, 0); // 0 == remove binding decorations (let's the compiler decide)
1410     Collect(compiler, res.subpass_inputs, 0);  // 0 == remove binding decorations (let's the compiler decide)
1411 
1412     // handle the real sampled images.
1413     Collect(compiler, res.sampled_images, 0); // 0 == remove binding decorations (let's the compiler decide)
1414 
1415     // and now the "generated ones" (separate image/sampler)
1416     std::string imageName;
1417     std::string samplerName;
1418     std::string temp;
1419     for (auto& remap : compiler.get_combined_image_samplers()) {
1420         const auto imageBinding = GetBinding(compiler, remap.image_id);
1421         {
1422             imageName.resize(imageName.capacity() - 1);
1423             const auto nameLen =
1424                 snprintf(imageName.data(), imageName.size(), "s%u_b%u", imageBinding.set, imageBinding.bind);
1425             if (nameLen < 0) {
1426                 LUME_LOG_E("Could not get imageName, error");
1427                 return;
1428             }
1429             imageName.resize(nameLen);
1430         }
1431         const auto samplerBinding = GetBinding(compiler, remap.sampler_id);
1432         {
1433             samplerName.resize(samplerName.capacity() - 1);
1434             const auto nameLen =
1435                 snprintf(samplerName.data(), samplerName.size(), "s%u_b%u", samplerBinding.set, samplerBinding.bind);
1436             if (nameLen < 0) {
1437                 LUME_LOG_E("Could not get sampleName, error");
1438                 return;
1439             }
1440             samplerName.resize(nameLen);
1441         }
1442 
1443         temp.reserve(imageName.size() + samplerName.size() + 1);
1444         temp.clear();
1445         temp.append(imageName);
1446         temp.append(1, '_');
1447         temp.append(samplerName);
1448         compiler.set_name(remap.combined_id, temp);
1449     }
1450 }
1451 
/** Device backend type */
enum class DeviceBackendType {
    /** Vulkan backend */
    VULKAN = 0,
    /** GLES backend */
    OPENGLES = 1,
    /** OpenGL backend */
    OPENGL = 2
};
1461 
SetupSpirvCross(ShaderStageFlags stage,Gles::CoreCompiler * compiler,DeviceBackendType backend,bool ovrEnabled)1462 void SetupSpirvCross(ShaderStageFlags stage, Gles::CoreCompiler* compiler, DeviceBackendType backend, bool ovrEnabled)
1463 {
1464     spirv_cross::CompilerGLSL::Options options;
1465     if (backend == DeviceBackendType::OPENGLES) {
1466         static constexpr auto glesVersion = 320;
1467         options.version = glesVersion;
1468         options.es = true;
1469         options.fragment.default_float_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1470         options.fragment.default_int_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1471     }
1472 
1473     if (backend == DeviceBackendType::OPENGL) {
1474         static constexpr auto glVersion = 450;
1475         options.version = glVersion;
1476         options.es = false;
1477     }
1478 
1479 #if defined(CORE_USE_SEPARATE_SHADER_OBJECTS) && (CORE_USE_SEPARATE_SHADER_OBJECTS == 1)
1480     if (stage & (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT)) {
1481         options.separate_shader_objects = true;
1482     }
1483 #endif
1484 
1485     options.ovr_multiview_view_count = ovrEnabled ? 1 : 0;
1486 
1487     compiler->set_common_options(options);
1488 }
1489 
1490 struct Shader {
1491     ShaderStageFlags shaderStageFlags;
1492     DeviceBackendType backend { DeviceBackendType::OPENGL };
1493     ShaderModulePlatformDataGLES plat;
1494     bool ovrEnabled { false };
1495 
1496     std::string source_;
1497 };
1498 
ProcessShaderModule(Shader & me,const ShaderModuleCreateInfo & createInfo)1499 void ProcessShaderModule(Shader& me, const ShaderModuleCreateInfo& createInfo)
1500 {
1501     // perform reflection.
1502     auto compiler = Gles::CoreCompiler(reinterpret_cast<const uint32_t*>(createInfo.spvData.data()),
1503         static_cast<uint32_t>(createInfo.spvData.size() / sizeof(uint32_t)));
1504     // Set some options.
1505     SetupSpirvCross(me.shaderStageFlags, &compiler, me.backend, me.ovrEnabled);
1506 
1507     // first step in converting CORE_FLIP_NDC to regular uniform. (specializationconstant -> constant) this makes
1508     // the compiled glsl more readable, and simpler to post process later.
1509     Gles::ConvertSpecConstToConstant(compiler, "CORE_FLIP_NDC");
1510 
1511     auto active = compiler.get_active_interface_variables();
1512     const auto& res = compiler.get_shader_resources(active);
1513     compiler.set_enabled_interface_variables(std::move(active));
1514 
1515     Gles::ReflectPushConstants(compiler, res, me.plat.infos, me.shaderStageFlags);
1516     compiler.build_combined_image_samplers();
1517     CollectRes(compiler, res, me.plat);
1518 
1519     // set "CORE_BACKEND_TYPE" specialization to 1.
1520     Gles::SetSpecMacro(compiler, "CORE_BACKEND_TYPE", 1U);
1521 
1522     me.source_ = compiler.compile();
1523     Gles::ConvertConstantToUniform(compiler, me.source_, "CORE_FLIP_NDC");
1524 }
1525 
1526 template<typename T>
WriteToFile(const array_view<T> & data,std::filesystem::path destinationFile)1527 bool WriteToFile(const array_view<T>& data, std::filesystem::path destinationFile)
1528 {
1529     std::ofstream outputStream(destinationFile, std::ios::out | std::ios::binary);
1530     if (outputStream.is_open()) {
1531         outputStream.write(reinterpret_cast<const char*>(data.data()), data.size() * sizeof(T));
1532         outputStream.close();
1533         return true;
1534     } else {
1535         LUME_LOG_E("Could not write file: '%s'", destinationFile.u8string().c_str());
1536         return false;
1537     }
1538 }
1539 
/**
 * Checks whether the shader source enables the GL_EXT_multiview extension.
 *
 * An occurrence of "GL_EXT_multiview" counts only when "#extension" precedes it and "enable" follows it
 * on the same line. Looking at the whole source instead would pair the name with unrelated directives,
 * e.g. a disabled multiview directive followed by some other enabled extension.
 * NOTE: only the "enable" behavior is recognized; "require" is not treated as enabling.
 *
 * @param shaderSource GLSL source text to scan.
 * @return true if a matching "#extension ... GL_EXT_multiview ... enable" line is found.
 */
bool UsesMultiviewExtension(std::string_view shaderSource)
{
    static constexpr std::string_view multiview = "GL_EXT_multiview";
    static constexpr std::string_view directive = "#extension";
    static constexpr std::string_view behavior = "enable";
    constexpr auto npos = std::string_view::npos;

    for (auto pos = shaderSource.find(multiview); pos != npos;
         pos = shaderSource.find(multiview, pos + multiview.size())) {
        // Determine the bounds of the line which contains this occurrence.
        const auto previousBreak = shaderSource.rfind('\n', pos);
        const auto lineBegin = (previousBreak == npos) ? 0U : (previousBreak + 1U);
        const auto nameEnd = pos + multiview.size();
        const auto nextBreak = shaderSource.find('\n', nameEnd);
        const auto lineEnd = (nextBreak == npos) ? shaderSource.size() : nextBreak;

        const auto beforeName = shaderSource.substr(lineBegin, pos - lineBegin);
        const auto afterName = shaderSource.substr(nameEnd, lineEnd - nameEnd);
        if ((beforeName.find(directive) != npos) && (afterName.find(behavior) != npos)) {
            return true;
        }
    }
    return false;
}
1552 
CreateGlShader(std::filesystem::path outputFilename,bool multiviewEnabled,DeviceBackendType backendType,const ShaderModuleCreateInfo & info)1553 bool CreateGlShader(std::filesystem::path outputFilename, bool multiviewEnabled, DeviceBackendType backendType,
1554     const ShaderModuleCreateInfo& info)
1555 {
1556     try {
1557         Shader shader;
1558         shader.shaderStageFlags = info.shaderStageFlags;
1559         shader.backend = backendType;
1560         shader.ovrEnabled = multiviewEnabled;
1561         ProcessShaderModule(shader, info);
1562         const auto* data = static_cast<const uint8_t*>(static_cast<const void*>(shader.source_.data()));
1563         WriteToFile(array_view(data, shader.source_.size()), outputFilename);
1564     } catch (std::exception const& e) {
1565         LUME_LOG_E("Failed to generate GL(ES) shader: %s", e.what());
1566         return false;
1567     }
1568     return true;
1569 }
1570 
CreateGlShaders(std::filesystem::path outputFilename,std::string_view shaderSource,ShaderKind shaderKind,array_view<const uint32_t> spvBinary,ShaderReflectionData reflectionData)1571 void CreateGlShaders(std::filesystem::path outputFilename, std::string_view shaderSource, ShaderKind shaderKind,
1572     array_view<const uint32_t> spvBinary, ShaderReflectionData reflectionData)
1573 {
1574     auto glFile = outputFilename;
1575     glFile += ".gl";
1576     const bool multiviewEnabled = UsesMultiviewExtension(shaderSource);
1577 
1578     const auto info = ShaderModuleCreateInfo {
1579         ShaderStageFlags(shaderKind),
1580         array_view(static_cast<const uint8_t*>(static_cast<const void*>(spvBinary.data())), spvBinary.size_bytes()),
1581         reflectionData,
1582     };
1583 
1584     CreateGlShader(glFile, multiviewEnabled, DeviceBackendType::OPENGL, info);
1585     glFile += "es";
1586     CreateGlShader(glFile, multiviewEnabled, DeviceBackendType::OPENGLES, info);
1587 }
1588 
RunAllCompilationStages(std::string_view inputFilename,CompilationSettings & settings,const std::optional<Inputs> & params)1589 bool RunAllCompilationStages(
1590     std::string_view inputFilename, CompilationSettings& settings, const std::optional<Inputs>& params)
1591 {
1592     try {
1593         const auto inputFilenamePath = std::filesystem::u8path(inputFilename);
1594         const std::filesystem::path relativeInputFilename =
1595             std::filesystem::relative(inputFilenamePath, settings.shaderSourcePath);
1596         const std::string relativeFilename = relativeInputFilename.u8string();
1597         const std::string extension = inputFilenamePath.extension().string();
1598         std::filesystem::path outputFilename = settings.compiledShaderDestinationPath / relativeInputFilename;
1599 
1600         // Make sure the output dir hierarchy exists.
1601         std::filesystem::create_directories(outputFilename.parent_path());
1602 
1603         ShaderKind shaderKind;
1604 
1605         // Just copying json files to the destination dir.
1606         if (extension == ".json") {
1607             if (!std::filesystem::exists(outputFilename) ||
1608                 !std::filesystem::equivalent(inputFilenamePath, outputFilename)) {
1609                 LUME_LOG_I("  %s", relativeFilename.c_str());
1610                 std::filesystem::copy(
1611                     inputFilenamePath, outputFilename, std::filesystem::copy_options::overwrite_existing);
1612             }
1613             return true;
1614         } else if (extension == ".vert") {
1615             shaderKind = ShaderKind::VERTEX;
1616         } else if (extension == ".frag") {
1617             shaderKind = ShaderKind::FRAGMENT;
1618         } else if (extension == ".comp") {
1619             shaderKind = ShaderKind::COMPUTE;
1620         } else {
1621             return false;
1622         }
1623 
1624         outputFilename += ".spv";
1625 
1626         LUME_LOG_I("  %s", relativeFilename.c_str());
1627 
1628         LUME_LOG_V("    input: '%s'", inputFilename.data());
1629         LUME_LOG_V("      dst: '%s'", settings.compiledShaderDestinationPath.u8string().c_str());
1630         LUME_LOG_V(" relative: '%s'", relativeFilename.c_str());
1631         LUME_LOG_V("   output: '%s'", outputFilename.u8string().c_str());
1632 
1633         const std::string shaderSource = ReadFileToString(inputFilename);
1634         if (shaderSource.empty()) {
1635             return false;
1636         }
1637 
1638         const std::string preProcessedShader = PreProcessShader(shaderSource, shaderKind, relativeFilename, settings);
1639         if (preProcessedShader.empty()) {
1640             return false;
1641         }
1642         if constexpr (false) {
1643             auto preprocessedFile = outputFilename;
1644             preprocessedFile += ".pre";
1645             if (!WriteToFile(array_view(preProcessedShader.data(), preProcessedShader.size()), preprocessedFile)) {
1646                 LUME_LOG_E("Failed to save preprocessed %s", preprocessedFile.u8string().data());
1647             }
1648         }
1649 
1650         auto spvBinary = CompileShaderToSpirvBinary(preProcessedShader, shaderKind, relativeFilename, settings);
1651         if (spvBinary.empty()) {
1652             return false;
1653         }
1654 
1655         const auto reflection = ReflectSpvBinary(spvBinary, shaderKind);
1656         if (reflection.empty()) {
1657             LUME_LOG_E("Failed to reflect %s", inputFilename.data());
1658         } else {
1659             auto reflectionFile = outputFilename;
1660             reflectionFile += ".lsb";
1661             if (!WriteToFile(array_view(reflection.data(), reflection.size()), reflectionFile)) {
1662                 LUME_LOG_E("Failed to save reflection %s", reflectionFile.u8string().data());
1663             }
1664         }
1665 
1666         // spirv-opt resets the passes everytime so then need to be setup
1667         if (params->optimizeSpirv) {
1668             settings.optimizer->RegisterPerformancePasses();
1669         }
1670 
1671         RegisterStripPreprocessorDebugInfoPass(settings.optimizer);
1672         if (!settings.optimizer->Run(spvBinary.data(), spvBinary.size(), &spvBinary)) {
1673             LUME_LOG_E("Failed to optimize %s", inputFilename.data());
1674         }
1675 
1676         // generate gl and gles shaders from optimized binary but with file names intact but with just preprocessor
1677         // extension directives stripped out
1678         CreateGlShaders(outputFilename, preProcessedShader, shaderKind, spvBinary, ShaderReflectionData { reflection });
1679 
1680         // strip out all other debug information like variable names, function names
1681         if (params->stripDebugInformation == true) {
1682             bool registerPass = settings.optimizer->RegisterPassFromFlag("--strip-debug");
1683             if (registerPass == false || !settings.optimizer->Run(spvBinary.data(), spvBinary.size(), &spvBinary)) {
1684                 LUME_LOG_E("Failed to strip debug information %s", inputFilename.data());
1685             }
1686         }
1687 
1688         // write the spirv-binary to disk
1689         if (!WriteToFile(array_view(spvBinary.data(), spvBinary.size()), outputFilename)) {
1690             return false;
1691         }
1692 
1693         LUME_LOG_D("  -> %s", outputFilename.u8string().c_str());
1694 
1695         return true;
1696     } catch (std::exception const& e) {
1697         LUME_LOG_E("Processing file failed '%s': %s", inputFilename.data(), e.what());
1698     }
1699     return false;
1700 }
1701 
/** Prints the supported shader types, usage examples and all recognized options to standard output. */
void ShowUsage()
{
    std::cout << "LumeShaderCompiler - Supported shader types: vertex (.vert), fragment (.frag), compute (.comp)\n\n"
                 "How to use: \n"
                 "LumeShaderCompiler.exe --source <source path> (sets destination path to same as source)\n"
                 "LumeShaderCompiler.exe --source <source path> --destination <destination path>\n"
                 "LumeShaderCompiler.exe --monitor (monitors changes in the source files)\n"
                 "\n"
                 "Options: \n"
                 "  --help                        (shows this message)\n"
                 "  --source <source path>        (directory containing the shader sources)\n"
                 "  --sourceFile <source file>    (processes a single file, source path is its directory)\n"
                 "  --destination <path>          (directory for the compiled output)\n"
                 "  --include <path>              (additional include search path, can be repeated)\n"
                 "  --monitor                     (monitors changes in the source files)\n"
                 "  --optimize                    (runs the SPIR-V performance optimization passes)\n"
                 "  --strip-debug-information     (strips debug information from the SPIR-V binary)\n"
                 "  --vulkan <version>            (target Vulkan version: 1.0, 1.1, 1.2 and, if supported by "
                 "the build, 1.3)\n";
}
1710 
/**
 * Returns the file names whose extension matches one of the given extensions, ignoring ASCII case.
 *
 * @param aFilenames UTF-8 encoded file names / paths to filter.
 * @param aIncludeExtensions Accepted extensions in lowercase including the leading dot, e.g. ".vert".
 * @return The accepted file names in their original order and spelling.
 */
std::vector<std::string> FilterByExtension(
    const std::vector<std::string>& aFilenames, const std::vector<std::string_view>& aIncludeExtensions)
{
    std::vector<std::string> filtered;
    for (auto const& file : aFilenames) {
        std::string lowercaseFileExt = std::filesystem::u8path(file).extension().u8string();
        // tolower requires a value representable as unsigned char; plain char may be negative for the
        // bytes of non-ASCII UTF-8 characters, which would be undefined behavior.
        std::transform(lowercaseFileExt.begin(), lowercaseFileExt.end(), lowercaseFileExt.begin(),
            [](unsigned char c) { return static_cast<char>(tolower(c)); });
        const auto match = std::find(
            aIncludeExtensions.cbegin(), aIncludeExtensions.cend(), std::string_view(lowercaseFileExt));
        if (match != aIncludeExtensions.cend()) {
            filtered.push_back(file);
        }
    }

    return filtered;
}
1727 
1728 struct Args {
1729     std::string_view name;
1730     int32_t additionalArguments;
1731     bool (*handler)(Inputs&, char* []);
1732 };
1733 
1734 constexpr Args ARGS[] = {
1735     { "--help", 0,
__anon44c0966e0f02() 1736         [](Inputs&, char *[]) {
1737             ShowUsage();
1738             return false;
1739         } },
1740     { "--sourceFile", 1,
__anon44c0966e1002() 1741         [](Inputs& params, char* argv[]) {
1742             params.sourceFile = std::filesystem::u8path(*argv);
1743             params.sourceFile.make_preferred();
1744             params.shaderSourcesPath = params.sourceFile;
1745             params.shaderSourcesPath.remove_filename();
1746             if (params.compiledShaderDestinationPath.empty()) {
1747                 params.compiledShaderDestinationPath = params.shaderSourcesPath;
1748             }
1749             return true;
1750         } },
1751     { "--source", 1,
__anon44c0966e1102() 1752         [](Inputs& params, char* argv[]) {
1753             params.shaderSourcesPath = std::filesystem::u8path(*argv);
1754             params.shaderSourcesPath.make_preferred();
1755             if (params.compiledShaderDestinationPath.empty()) {
1756                 params.compiledShaderDestinationPath = params.shaderSourcesPath;
1757             }
1758             return true;
1759         } },
1760     { "--destination", 1,
__anon44c0966e1202() 1761         [](Inputs& params, char* argv[]) {
1762             params.compiledShaderDestinationPath = std::filesystem::u8path(*argv);
1763             params.compiledShaderDestinationPath.make_preferred();
1764             return true;
1765         } },
1766     { "--include", 1,
__anon44c0966e1302() 1767         [](Inputs& params, char* argv[]) {
1768             params.shaderIncludePaths.emplace_back(std::filesystem::u8path(*argv)).make_preferred();
1769             return true;
1770         } },
1771     { "--monitor", 0,
__anon44c0966e1402() 1772         [](Inputs& params, char* argv[]) {
1773             params.monitorChanges = true;
1774             return true;
1775         } },
1776     { "--optimize", 0,
__anon44c0966e1502() 1777         [](Inputs& params, char* argv[]) {
1778             params.optimizeSpirv = true;
1779             return true;
1780         } },
1781     { "--strip-debug-information", 0,
__anon44c0966e1602() 1782         [](Inputs& params, char* argv[]) {
1783             params.stripDebugInformation = true;
1784             return true;
1785         } },
1786     { "--vulkan", 1,
__anon44c0966e1702() 1787         [](Inputs& params, char* argv[]) {
1788             const auto version = std::string_view(*argv);
1789             if (version == "1.0") {
1790                 params.envVersion = ShaderEnv::version_vulkan_1_0;
1791             } else if (version == "1.1") {
1792                 params.envVersion = ShaderEnv::version_vulkan_1_1;
1793             } else if (version == "1.2") {
1794                 params.envVersion = ShaderEnv::version_vulkan_1_2;
1795 #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
1796             } else if (version == "1.3") {
1797                 params.envVersion = ShaderEnv::version_vulkan_1_3;
1798 #endif
1799             } else {
1800                 LUME_LOG_E("Invalid argument for option --vulkan.");
1801                 return false;
1802             }
1803             return true;
1804         } },
1805 };
1806 
Parse(const int argc,char * argv[])1807 std::optional<Inputs> Parse(const int argc, char* argv[])
1808 {
1809     Inputs params;
1810     const std::filesystem::path currentFolder = std::filesystem::current_path();
1811     params.shaderSourcesPath = currentFolder;
1812 
1813     for (int i = 1; i < argc;) {
1814         const auto arg = std::string_view(argv[i]);
1815         const auto pos =
1816             std::find_if(std::begin(ARGS), std::end(ARGS), [arg](const Args& item) { return item.name == arg; });
1817         if (pos == std::end(ARGS)) {
1818             LUME_LOG_E("Unknonw argument %.*s.\n", int(arg.size()), arg.data());
1819             return std::nullopt;
1820         }
1821         if ((i + pos->additionalArguments) >= argc) {
1822             LUME_LOG_E("%.*s option requires %u additional arguments.\n", int(arg.size()), arg.data(),
1823                 pos->additionalArguments);
1824             return std::nullopt;
1825         }
1826 
1827         if (!pos->handler(params, argv + (i + 1))) {
1828             return std::nullopt;
1829         }
1830         i = (i + 1) + pos->additionalArguments;
1831     }
1832     if (params.compiledShaderDestinationPath.empty()) {
1833         params.compiledShaderDestinationPath = currentFolder;
1834     }
1835 
1836     return params;
1837 }
1838 } // namespace
1839 
CompilerMain(int argc,char * argv[])1840 int CompilerMain(int argc, char* argv[])
1841 {
1842     if (argc == 1) {
1843         ShowUsage();
1844         return 0;
1845     }
1846 
1847     const std::optional<Inputs> params = Parse(argc, argv);
1848     if (!params) {
1849         return 1;
1850     }
1851     ige::FileMonitor fileMonitor;
1852 
1853     if (!std::filesystem::exists(params->shaderSourcesPath)) {
1854         LUME_LOG_E("Source path does not exist: '%s'", params->shaderSourcesPath.u8string().c_str());
1855         return 1;
1856     }
1857 
1858     // Make sure the destination dir exists.
1859     std::filesystem::create_directories(params->compiledShaderDestinationPath);
1860 
1861     if (!std::filesystem::exists(params->compiledShaderDestinationPath)) {
1862         LUME_LOG_E("Destination path does not exist: '%s'", params->compiledShaderDestinationPath.u8string().c_str());
1863         return 1;
1864     }
1865 
1866     fileMonitor.AddPath(params->shaderSourcesPath.u8string());
1867     std::vector<std::string> fileList = [&]() {
1868         std::vector<std::string> list;
1869         if (!params->sourceFile.empty()) {
1870             list.push_back(params->sourceFile.u8string());
1871         } else {
1872             list = fileMonitor.GetMonitoredFiles();
1873         }
1874         return list;
1875     }();
1876 
1877     const std::vector<std::string_view> supportedFileTypes = { ".vert", ".frag", ".comp", ".json" };
1878     fileList = FilterByExtension(fileList, supportedFileTypes);
1879 
1880     LUME_LOG_I("     Source path: '%s'", std::filesystem::absolute(params->shaderSourcesPath).u8string().c_str());
1881     for (auto const& path : params->shaderIncludePaths) {
1882         LUME_LOG_I("    Include path: '%s'", std::filesystem::absolute(path).u8string().c_str());
1883     }
1884     LUME_LOG_I(
1885         "Destination path: '%s'", std::filesystem::absolute(params->compiledShaderDestinationPath).u8string().c_str());
1886     LUME_LOG_I("");
1887     LUME_LOG_I("Processing:");
1888 
1889     int errorCount = 0;
1890     Scope scope(glslang::InitializeProcess, glslang::FinalizeProcess);
1891 
1892     std::vector<std::filesystem::path> searchPath;
1893     searchPath.reserve(1U + params->shaderIncludePaths.size());
1894     searchPath.emplace_back(params->shaderSourcesPath.u8string());
1895     std::transform(params->shaderIncludePaths.cbegin(), params->shaderIncludePaths.cend(),
1896         std::back_inserter(searchPath), [](const std::filesystem::path& path) { return path.u8string(); });
1897 
1898     auto fileIncluder = FileIncluder(params->shaderSourcesPath, searchPath);
1899     auto settings = CompilationSettings { params->envVersion, searchPath, {}, params->shaderSourcesPath,
1900         params->compiledShaderDestinationPath, fileIncluder };
1901 
1902     {
1903         spv_target_env targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1904         switch (params->envVersion) {
1905             case ShaderEnv::version_vulkan_1_0:
1906                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1907                 break;
1908             case ShaderEnv::version_vulkan_1_1:
1909                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_1;
1910                 break;
1911             case ShaderEnv::version_vulkan_1_2:
1912                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_2;
1913                 break;
1914             case ShaderEnv::version_vulkan_1_3:
1915                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_3;
1916                 break;
1917             default:
1918                 break;
1919         }
1920         settings.optimizer.emplace(targetEnv);
1921     }
1922 
1923     // Startup compilation.
1924     for (auto const& file : fileList) {
1925         std::string relativeFilename = std::filesystem::relative(file, params->shaderSourcesPath).u8string();
1926         LUME_LOG_D("Tracked source file: '%s'", relativeFilename.c_str());
1927         if (!RunAllCompilationStages(file, settings, params)) {
1928             errorCount++;
1929         }
1930     }
1931 
1932     if (errorCount == 0) {
1933         LUME_LOG_I("Success.");
1934     } else {
1935         LUME_LOG_E("Failed: %d", errorCount);
1936     }
1937 
1938     if (params->monitorChanges) {
1939         LUME_LOG_I("Monitoring file changes.");
1940     }
1941 
1942     // Main loop.
1943     while (params->monitorChanges) {
1944         std::vector<std::string> addedFiles;
1945         std::vector<std::string> removedFiles;
1946         std::vector<std::string> modifiedFiles;
1947         fileMonitor.ScanModifications(addedFiles, removedFiles, modifiedFiles);
1948         modifiedFiles = FilterByExtension(modifiedFiles, supportedFileTypes);
1949 
1950         if (params->sourceFile.empty()) {
1951             fileIncluder.Reset();
1952             addedFiles = FilterByExtension(addedFiles, supportedFileTypes);
1953             removedFiles = FilterByExtension(removedFiles, supportedFileTypes);
1954 
1955             if (!addedFiles.empty()) {
1956                 LUME_LOG_I("Files added:");
1957                 for (auto const& addedFile : addedFiles) {
1958                     RunAllCompilationStages(addedFile, settings, params);
1959                 }
1960             }
1961 
1962             if (!removedFiles.empty()) {
1963                 LUME_LOG_I("Files removed:");
1964                 for (auto const& removedFile : removedFiles) {
1965                     std::string relativeFilename =
1966                         std::filesystem::relative(removedFile, params->shaderSourcesPath).u8string();
1967                     LUME_LOG_I("  %s", relativeFilename.c_str());
1968                 }
1969             }
1970 
1971             if (!modifiedFiles.empty()) {
1972                 LUME_LOG_I("Files modified:");
1973                 for (auto const& modifiedFile : modifiedFiles) {
1974                     RunAllCompilationStages(modifiedFile, settings, params);
1975                 }
1976             }
1977         } else if (!modifiedFiles.empty()) {
1978             fileIncluder.Reset();
1979             auto pos = std::find_if(modifiedFiles.cbegin(), modifiedFiles.cend(),
1980                 [&sourceFile = params->sourceFile](const std::string& modified) { return modified == sourceFile; });
1981             if (pos != modifiedFiles.cend()) {
1982                 RunAllCompilationStages(*pos, settings, params);
1983             }
1984         }
1985 
1986         std::this_thread::sleep_for(std::chrono::seconds(1));
1987     }
1988 
1989     return errorCount;
1990 }
1991