/*
 * Copyright 2016-2021 Robert Konrad
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
 * At your option, you may choose to accept this material under either:
 *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 */

#include "spirv_hlsl.hpp"
#include "GLSL.std.450.h"
#include <algorithm>
#include <assert.h>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

enum class ImageFormatNormalizedState
{
    None = 0,
    Unorm = 1,
    Snorm = 2
};

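// Classifies a typed image format by its normalization, e.g. ImageFormatRgba8
// maps to Unorm and ImageFormatR8Snorm maps to Snorm; everything else is None.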
static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
{
    switch (fmt)
    {
    case ImageFormatR8:
    case ImageFormatR16:
    case ImageFormatRg8:
    case ImageFormatRg16:
    case ImageFormatRgba8:
    case ImageFormatRgba16:
    case ImageFormatRgb10A2:
        return ImageFormatNormalizedState::Unorm;

    case ImageFormatR8Snorm:
    case ImageFormatR16Snorm:
    case ImageFormatRg8Snorm:
    case ImageFormatRg16Snorm:
    case ImageFormatRgba8Snorm:
    case ImageFormatRgba16Snorm:
        return ImageFormatNormalizedState::Snorm;

    default:
        break;
    }

    return ImageFormatNormalizedState::None;
}

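// Returns the component count implied by a typed image format, e.g.
// ImageFormatRg16f yields 2. Unknown formats conservatively report 4.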
static unsigned image_format_to_components(ImageFormat fmt)
{
    switch (fmt)
    {
    case ImageFormatR8:
    case ImageFormatR16:
    case ImageFormatR8Snorm:
    case ImageFormatR16Snorm:
    case ImageFormatR16f:
    case ImageFormatR32f:
    case ImageFormatR8i:
    case ImageFormatR16i:
    case ImageFormatR32i:
    case ImageFormatR8ui:
    case ImageFormatR16ui:
    case ImageFormatR32ui:
        return 1;

    case ImageFormatRg8:
    case ImageFormatRg16:
    case ImageFormatRg8Snorm:
    case ImageFormatRg16Snorm:
    case ImageFormatRg16f:
    case ImageFormatRg32f:
    case ImageFormatRg8i:
    case ImageFormatRg16i:
    case ImageFormatRg32i:
    case ImageFormatRg8ui:
    case ImageFormatRg16ui:
    case ImageFormatRg32ui:
        return 2;

    case ImageFormatR11fG11fB10f:
        return 3;

    case ImageFormatRgba8:
    case ImageFormatRgba16:
    case ImageFormatRgb10A2:
    case ImageFormatRgba8Snorm:
    case ImageFormatRgba16Snorm:
    case ImageFormatRgba16f:
    case ImageFormatRgba32f:
    case ImageFormatRgba8i:
    case ImageFormatRgba16i:
    case ImageFormatRgba32i:
    case ImageFormatRgba8ui:
    case ImageFormatRgba16ui:
    case ImageFormatRgba32ui:
    case ImageFormatRgb10a2ui:
        return 4;

    case ImageFormatUnknown:
        return 4; // Assume 4.

    default:
        SPIRV_CROSS_THROW("Unrecognized typed image format.");
    }
}

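// Maps a typed image format to the HLSL template type used for typed UAV loads,
// validating it against the SPIR-V sampled type. For example,
// image_format_to_type(ImageFormatRgba8, SPIRType::Float) yields "unorm float4"
// and image_format_to_type(ImageFormatR32ui, SPIRType::UInt) yields "uint".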
static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
{
    switch (fmt)
    {
    case ImageFormatR8:
    case ImageFormatR16:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "unorm float";
    case ImageFormatRg8:
    case ImageFormatRg16:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "unorm float2";
    case ImageFormatRgba8:
    case ImageFormatRgba16:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "unorm float4";
    case ImageFormatRgb10A2:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "unorm float4";

    case ImageFormatR8Snorm:
    case ImageFormatR16Snorm:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "snorm float";
    case ImageFormatRg8Snorm:
    case ImageFormatRg16Snorm:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "snorm float2";
    case ImageFormatRgba8Snorm:
    case ImageFormatRgba16Snorm:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "snorm float4";

    case ImageFormatR16f:
    case ImageFormatR32f:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "float";
    case ImageFormatRg16f:
    case ImageFormatRg32f:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "float2";
    case ImageFormatRgba16f:
    case ImageFormatRgba32f:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "float4";

    case ImageFormatR11fG11fB10f:
        if (basetype != SPIRType::Float)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "float3";

    case ImageFormatR8i:
    case ImageFormatR16i:
    case ImageFormatR32i:
        if (basetype != SPIRType::Int)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "int";
    case ImageFormatRg8i:
    case ImageFormatRg16i:
    case ImageFormatRg32i:
        if (basetype != SPIRType::Int)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "int2";
    case ImageFormatRgba8i:
    case ImageFormatRgba16i:
    case ImageFormatRgba32i:
        if (basetype != SPIRType::Int)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "int4";

    case ImageFormatR8ui:
    case ImageFormatR16ui:
    case ImageFormatR32ui:
        if (basetype != SPIRType::UInt)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "uint";
    case ImageFormatRg8ui:
    case ImageFormatRg16ui:
    case ImageFormatRg32ui:
        if (basetype != SPIRType::UInt)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "uint2";
    case ImageFormatRgba8ui:
    case ImageFormatRgba16ui:
    case ImageFormatRgba32ui:
        if (basetype != SPIRType::UInt)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "uint4";
    case ImageFormatRgb10a2ui:
        if (basetype != SPIRType::UInt)
            SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
        return "uint4";

    case ImageFormatUnknown:
        switch (basetype)
        {
        case SPIRType::Float:
            return "float4";
        case SPIRType::Int:
            return "int4";
        case SPIRType::UInt:
            return "uint4";
        default:
            SPIRV_CROSS_THROW("Unsupported base type for image.");
        }

    default:
        SPIRV_CROSS_THROW("Unrecognized typed image format.");
    }
}

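// Builds the SM 4.0+ resource type for an image. For example, a sampled 2D float
// image becomes "Texture2D<float4>", while a Rgba8 storage image with typed loads
// becomes "RWTexture2D<unorm float4>" (or a plain SRV when
// nonwritable_uav_texture_as_srv applies).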
string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
{
    auto &imagetype = get<SPIRType>(type.image.type);
    const char *dim = nullptr;
    bool typed_load = false;
    uint32_t components = 4;

    bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, DecorationNonWritable);

    switch (type.image.dim)
    {
    case Dim1D:
        typed_load = type.image.sampled == 2;
        dim = "1D";
        break;
    case Dim2D:
        typed_load = type.image.sampled == 2;
        dim = "2D";
        break;
    case Dim3D:
        typed_load = type.image.sampled == 2;
        dim = "3D";
        break;
    case DimCube:
        if (type.image.sampled == 2)
            SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
        dim = "Cube";
        break;
    case DimRect:
        SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
    case DimBuffer:
        if (type.image.sampled == 1)
            return join("Buffer<", type_to_glsl(imagetype), components, ">");
        else if (type.image.sampled == 2)
        {
            if (interlocked_resources.count(id))
                return join("RasterizerOrderedBuffer<", image_format_to_type(type.image.format, imagetype.basetype),
                            ">");

            typed_load = !force_image_srv && type.image.sampled == 2;

            const char *rw = force_image_srv ? "" : "RW";
            return join(rw, "Buffer<",
                        typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
                                     join(type_to_glsl(imagetype), components),
                        ">");
        }
        else
            SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce at runtime.");
    case DimSubpassData:
        dim = "2D";
        typed_load = false;
        break;
    default:
        SPIRV_CROSS_THROW("Invalid dimension.");
    }
    const char *arrayed = type.image.arrayed ? "Array" : "";
    const char *ms = type.image.ms ? "MS" : "";
    const char *rw = typed_load && !force_image_srv ? "RW" : "";

    if (force_image_srv)
        typed_load = false;

    if (typed_load && interlocked_resources.count(id))
        rw = "RasterizerOrdered";

    return join(rw, "Texture", dim, ms, arrayed, "<",
                typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
                             join(type_to_glsl(imagetype), components),
                ">");
}

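// For SM 3.0 and lower, emit DX9-style sampler types instead, e.g. "sampler2D"
// for a 2D texture and "samplerCUBE" for a cube map.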
string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
{
    auto &imagetype = get<SPIRType>(type.image.type);
    string res;

    switch (imagetype.basetype)
    {
    case SPIRType::Int:
        res = "i";
        break;
    case SPIRType::UInt:
        res = "u";
        break;
    default:
        break;
    }

    if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
        return res + "subpassInput" + (type.image.ms ? "MS" : "");

    // If we're emulating subpassInput with samplers, force sampler2D
    // so we don't have to specify format.
    if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
    {
        // Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
        if (type.image.dim == DimBuffer && type.image.sampled == 1)
            res += "sampler";
        else
            res += type.image.sampled == 2 ? "image" : "texture";
    }
    else
        res += "sampler";

    switch (type.image.dim)
    {
    case Dim1D:
        res += "1D";
        break;
    case Dim2D:
        res += "2D";
        break;
    case Dim3D:
        res += "3D";
        break;
    case DimCube:
        res += "CUBE";
        break;

    case DimBuffer:
        res += "Buffer";
        break;

    case DimSubpassData:
        res += "2D";
        break;
    default:
        SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
    }

    if (type.image.ms)
        res += "MS";
    if (type.image.arrayed)
        res += "Array";

    return res;
}

string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
{
    if (hlsl_options.shader_model <= 30)
        return image_type_hlsl_legacy(type, id);
    else
        return image_type_hlsl_modern(type, id);
}

// The optional id parameter indicates the object whose type we are trying
// to find the description for. Most type descriptions do not depend on a
// specific object's use of that type.
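// For example, a float with vecsize 3 maps to "float3", and a float matrix with
// 4 columns and vecsize 4 maps to "float4x4".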
string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
{
    // Ignore the pointer type since GLSL doesn't have pointers.

    switch (type.basetype)
    {
    case SPIRType::Struct:
        // Need OpName lookup here to get a "sensible" name for a struct.
        if (backend.explicit_struct_type)
            return join("struct ", to_name(type.self));
        else
            return to_name(type.self);

    case SPIRType::Image:
    case SPIRType::SampledImage:
        return image_type_hlsl(type, id);

    case SPIRType::Sampler:
        return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState";

    case SPIRType::Void:
        return "void";

    default:
        break;
    }

    if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
    {
        switch (type.basetype)
        {
        case SPIRType::Boolean:
            return "bool";
        case SPIRType::Int:
            return backend.basic_int_type;
        case SPIRType::UInt:
            return backend.basic_uint_type;
        case SPIRType::AtomicCounter:
            return "atomic_uint";
        case SPIRType::Half:
            if (hlsl_options.enable_16bit_types)
                return "half";
            else
                return "min16float";
        case SPIRType::Short:
            if (hlsl_options.enable_16bit_types)
                return "int16_t";
            else
                return "min16int";
        case SPIRType::UShort:
            if (hlsl_options.enable_16bit_types)
                return "uint16_t";
            else
                return "min16uint";
        case SPIRType::Float:
            return "float";
        case SPIRType::Double:
            return "double";
        case SPIRType::Int64:
            if (hlsl_options.shader_model < 60)
                SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
            return "int64_t";
        case SPIRType::UInt64:
            if (hlsl_options.shader_model < 60)
                SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
            return "uint64_t";
        default:
            return "???";
        }
    }
    else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
    {
        switch (type.basetype)
        {
        case SPIRType::Boolean:
            return join("bool", type.vecsize);
        case SPIRType::Int:
            return join("int", type.vecsize);
        case SPIRType::UInt:
            return join("uint", type.vecsize);
        case SPIRType::Half:
            return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
        case SPIRType::Short:
            return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
        case SPIRType::UShort:
            return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
        case SPIRType::Float:
            return join("float", type.vecsize);
        case SPIRType::Double:
            return join("double", type.vecsize);
        case SPIRType::Int64:
            return join("i64vec", type.vecsize);
        case SPIRType::UInt64:
            return join("u64vec", type.vecsize);
        default:
            return "???";
        }
    }
    else
    {
        switch (type.basetype)
        {
        case SPIRType::Boolean:
            return join("bool", type.columns, "x", type.vecsize);
        case SPIRType::Int:
            return join("int", type.columns, "x", type.vecsize);
        case SPIRType::UInt:
            return join("uint", type.columns, "x", type.vecsize);
        case SPIRType::Half:
            return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
        case SPIRType::Short:
            return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
        case SPIRType::UShort:
            return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
        case SPIRType::Float:
            return join("float", type.columns, "x", type.vecsize);
        case SPIRType::Double:
            return join("double", type.columns, "x", type.vecsize);
        // Matrix types not supported for int64/uint64.
        default:
            return "???";
        }
    }
}

void CompilerHLSL::emit_header()
{
    for (auto &header : header_lines)
        statement(header);

    if (header_lines.size() > 0)
    {
        statement("");
    }
}

void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
{
    add_resource_name(var.self);

    // The global copies of I/O variables should not contain interpolation qualifiers.
    // These are emitted inside the interface structs.
    auto &flags = ir.meta[var.self].decoration.decoration_flags;
    auto old_flags = flags;
    flags.reset();
    statement("static ", variable_decl(var), ";");
    flags = old_flags;
}

const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
{
    // Input and output variables are handled specially in HLSL backend.
    // The variables are declared as global, private variables, and do not need any qualifiers.
    if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
        var.storage == StorageClassPushConstant)
    {
        return "uniform ";
    }

    return "";
}

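// Declares the active output builtins as members of the output struct, e.g.
// BuiltInPosition on SM 4.0+ is emitted as "float4 gl_Position : SV_Position;".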
void CompilerHLSL::emit_builtin_outputs_in_struct()
{
    auto &execution = get_entry_point();

    bool legacy = hlsl_options.shader_model <= 30;
    active_output_builtins.for_each_bit([&](uint32_t i) {
        const char *type = nullptr;
        const char *semantic = nullptr;
        auto builtin = static_cast<BuiltIn>(i);
        switch (builtin)
        {
        case BuiltInPosition:
            type = is_position_invariant() && backend.support_precise_qualifier ? "precise float4" : "float4";
            semantic = legacy ? "POSITION" : "SV_Position";
            break;

        case BuiltInSampleMask:
            if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
                SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
            type = "uint";
            semantic = "SV_Coverage";
            break;

        case BuiltInFragDepth:
            type = "float";
            if (legacy)
            {
                semantic = "DEPTH";
            }
            else
            {
                if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthGreater))
                    semantic = "SV_DepthGreaterEqual";
                else if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthLess))
                    semantic = "SV_DepthLessEqual";
                else
                    semantic = "SV_Depth";
            }
            break;

        case BuiltInClipDistance:
            // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
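            // For example, clip_distance_count == 5 emits:
            //   float4 gl_ClipDistance0 : SV_ClipDistance0;
            //   float gl_ClipDistance1 : SV_ClipDistance1;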
            for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
            {
                uint32_t to_declare = clip_distance_count - clip;
                if (to_declare > 4)
                    to_declare = 4;

                uint32_t semantic_index = clip / 4;

                static const char *types[] = { "float", "float2", "float3", "float4" };
                statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
                          " : SV_ClipDistance", semantic_index, ";");
            }
            break;

        case BuiltInCullDistance:
            // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
            for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
            {
                uint32_t to_declare = cull_distance_count - cull;
                if (to_declare > 4)
                    to_declare = 4;

                uint32_t semantic_index = cull / 4;

                static const char *types[] = { "float", "float2", "float3", "float4" };
                statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
                          " : SV_CullDistance", semantic_index, ";");
            }
            break;

        case BuiltInPointSize:
            // If point_size_compat is enabled, just ignore PointSize.
            // PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
            // even if it means working around the missing feature.
            if (hlsl_options.point_size_compat)
                break;
            else
                SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");

        default:
            SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
        }

        if (type && semantic)
            statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
    });
}

void CompilerHLSL::emit_builtin_inputs_in_struct()
{
    bool legacy = hlsl_options.shader_model <= 30;
    active_input_builtins.for_each_bit([&](uint32_t i) {
        const char *type = nullptr;
        const char *semantic = nullptr;
        auto builtin = static_cast<BuiltIn>(i);
        switch (builtin)
        {
        case BuiltInFragCoord:
            type = "float4";
            semantic = legacy ? "VPOS" : "SV_Position";
            break;

        case BuiltInVertexId:
        case BuiltInVertexIndex:
            if (legacy)
                SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
            type = "uint";
            semantic = "SV_VertexID";
            break;

        case BuiltInInstanceId:
        case BuiltInInstanceIndex:
            if (legacy)
                SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
            type = "uint";
            semantic = "SV_InstanceID";
            break;

        case BuiltInSampleId:
            if (legacy)
                SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
            type = "uint";
            semantic = "SV_SampleIndex";
            break;

        case BuiltInSampleMask:
            if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
                SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
            type = "uint";
            semantic = "SV_Coverage";
            break;

        case BuiltInGlobalInvocationId:
            type = "uint3";
            semantic = "SV_DispatchThreadID";
            break;

        case BuiltInLocalInvocationId:
            type = "uint3";
            semantic = "SV_GroupThreadID";
            break;

        case BuiltInLocalInvocationIndex:
            type = "uint";
            semantic = "SV_GroupIndex";
            break;

        case BuiltInWorkgroupId:
            type = "uint3";
            semantic = "SV_GroupID";
            break;

        case BuiltInFrontFacing:
            type = "bool";
            semantic = "SV_IsFrontFace";
            break;

        case BuiltInNumWorkgroups:
        case BuiltInSubgroupSize:
        case BuiltInSubgroupLocalInvocationId:
        case BuiltInSubgroupEqMask:
        case BuiltInSubgroupLtMask:
        case BuiltInSubgroupLeMask:
        case BuiltInSubgroupGtMask:
        case BuiltInSubgroupGeMask:
            // Handled specially.
            break;

        case BuiltInClipDistance:
            // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
            for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
            {
                uint32_t to_declare = clip_distance_count - clip;
                if (to_declare > 4)
                    to_declare = 4;

                uint32_t semantic_index = clip / 4;

                static const char *types[] = { "float", "float2", "float3", "float4" };
                statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
                          " : SV_ClipDistance", semantic_index, ";");
            }
            break;

        case BuiltInCullDistance:
            // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
            for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
            {
                uint32_t to_declare = cull_distance_count - cull;
                if (to_declare > 4)
                    to_declare = 4;

                uint32_t semantic_index = cull / 4;

                static const char *types[] = { "float", "float2", "float3", "float4" };
                statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
                          " : SV_CullDistance", semantic_index, ";");
            }
            break;

        case BuiltInPointCoord:
            // PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
            if (hlsl_options.point_coord_compat)
                break;
            else
                SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");

        default:
            SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
        }

        if (type && semantic)
            statement(type, " ", builtin_to_glsl(builtin, StorageClassInput), " : ", semantic, ";");
    });
}

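// Computes how many I/O locations a type consumes: arrays multiply, matrices
// consume one location per column, and structs sum their members. For example,
// a float4x4 consumes 4 locations and a float4[3] consumes 3.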
uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
{
    // TODO: Need to verify correctness.
    uint32_t elements = 0;

    if (type.basetype == SPIRType::Struct)
    {
        for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
            elements += type_to_consumed_locations(get<SPIRType>(type.member_types[i]));
    }
    else
    {
        uint32_t array_multiplier = 1;
        for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
        {
            if (type.array_size_literal[i])
                array_multiplier *= type.array[i];
            else
                array_multiplier *= evaluate_constant_u32(type.array[i]);
        }
        elements += array_multiplier * type.columns;
    }
    return elements;
}

string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
{
    string res;
    //if (flags & (1ull << DecorationSmooth))
    //    res += "linear ";
    if (flags.get(DecorationFlat))
        res += "nointerpolation ";
    if (flags.get(DecorationNoPerspective))
        res += "noperspective ";
    if (flags.get(DecorationCentroid))
        res += "centroid ";
    if (flags.get(DecorationPatch))
        res += "patch "; // Seems to be different in actual HLSL.
    if (flags.get(DecorationSample))
        res += "sample ";
    if (flags.get(DecorationInvariant) && backend.support_precise_qualifier)
        res += "precise "; // Not supported?

    return res;
}

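// For example, an unremapped location 2 becomes "TEXCOORD2", while an
// illustrative remap_vertex_attributes entry { 2, "NORMAL" } turns it into "NORMAL".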
std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
{
    if (em == ExecutionModelVertex && sc == StorageClassInput)
    {
        // We have a vertex attribute - we should look at remapping it if the user provided
        // vertex attribute hints.
        for (auto &attribute : remap_vertex_attributes)
            if (attribute.location == location)
                return attribute.semantic;
    }

    // Not a vertex attribute, or no remap_vertex_attributes entry.
    return join("TEXCOORD", location);
}

std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
{
    // We cannot emit static const initializer for block constants for practical reasons,
    // so just inline the initializer.
    // FIXME: There is a theoretical problem here if someone tries to composite extract
    // into this initializer since we don't declare it properly, but that is somewhat non-sensical.
    auto &type = get<SPIRType>(var.basetype);
    bool is_block = has_decoration(type.self, DecorationBlock);
    auto *c = maybe_get<SPIRConstant>(var.initializer);
    if (is_block && c)
        return constant_expression(*c);
    else
        return CompilerGLSL::to_initializer_expression(var);
}

void CompilerHLSL::emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index,
                                                         uint32_t location,
                                                         std::unordered_set<uint32_t> &active_locations)
{
    auto &execution = get_entry_point();
    auto type = get<SPIRType>(var.basetype);
    auto semantic = to_semantic(location, execution.model, var.storage);
    auto mbr_name = join(to_name(type.self), "_", to_member_name(type, member_index));
    auto &mbr_type = get<SPIRType>(type.member_types[member_index]);

    statement(to_interpolation_qualifiers(get_member_decoration_bitset(type.self, member_index)),
              type_to_glsl(mbr_type),
              " ", mbr_name, type_to_array_glsl(mbr_type),
              " : ", semantic, ";");

    // Structs and arrays should consume more locations.
    uint32_t consumed_locations = type_to_consumed_locations(mbr_type);
    for (uint32_t i = 0; i < consumed_locations; i++)
        active_locations.insert(location + i);
}

void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
{
    auto &execution = get_entry_point();
    auto type = get<SPIRType>(var.basetype);

    string binding;
    bool use_location_number = true;
    bool legacy = hlsl_options.shader_model <= 30;
    if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
    {
        // Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
        uint32_t index = get_decoration(var.self, DecorationIndex);
        uint32_t location = get_decoration(var.self, DecorationLocation);

        if (index != 0 && location != 0)
            SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");

        binding = join(legacy ? "COLOR" : "SV_Target", location + index);
        use_location_number = false;
        if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
            type.vecsize = 4;
    }

    const auto get_vacant_location = [&]() -> uint32_t {
        for (uint32_t i = 0; i < 64; i++)
            if (!active_locations.count(i))
                return i;
        SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
    };

    bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;

    auto name = to_name(var.self);
    if (use_location_number)
    {
        uint32_t location_number;

        // If an explicit location exists, use it with TEXCOORD[N] semantic.
        // Otherwise, pick a vacant location.
        if (has_decoration(var.self, DecorationLocation))
            location_number = get_decoration(var.self, DecorationLocation);
        else
            location_number = get_vacant_location();

        // Allow semantic remap if specified.
        auto semantic = to_semantic(location_number, execution.model, var.storage);

        if (need_matrix_unroll && type.columns > 1)
        {
            if (!type.array.empty())
                SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");

            // Unroll matrices.
            for (uint32_t i = 0; i < type.columns; i++)
            {
                SPIRType newtype = type;
                newtype.columns = 1;

                string effective_semantic;
                if (hlsl_options.flatten_matrix_vertex_input_semantics)
                    effective_semantic = to_semantic(location_number, execution.model, var.storage);
                else
                    effective_semantic = join(semantic, "_", i);

                statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)),
                          variable_decl(newtype, join(name, "_", i)), " : ", effective_semantic, ";");
                active_locations.insert(location_number++);
            }
        }
        else
        {
            statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(type, name), " : ",
                      semantic, ";");

            // Structs and arrays should consume more locations.
            uint32_t consumed_locations = type_to_consumed_locations(type);
            for (uint32_t i = 0; i < consumed_locations; i++)
                active_locations.insert(location_number + i);
        }
    }
    else
        statement(variable_decl(type, name), " : ", binding, ";");
}

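// Overrides the GLSL builtin names where HLSL needs something else, e.g.
// BuiltInSubgroupSize becomes "WaveGetLaneCount()"; anything unhandled falls
// back to CompilerGLSL::builtin_to_glsl().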
std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
{
    switch (builtin)
    {
    case BuiltInVertexId:
        return "gl_VertexID";
    case BuiltInInstanceId:
        return "gl_InstanceID";
    case BuiltInNumWorkgroups:
    {
        if (!num_workgroups_builtin)
            SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
                              "Cannot emit code for this builtin.");

        auto &var = get<SPIRVariable>(num_workgroups_builtin);
        auto &type = get<SPIRType>(var.basetype);
        auto ret = join(to_name(num_workgroups_builtin), "_", get_member_name(type.self, 0));
        ParsedIR::sanitize_underscores(ret);
        return ret;
    }
    case BuiltInPointCoord:
        // Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
        return "float2(0.5f, 0.5f)";
    case BuiltInSubgroupLocalInvocationId:
        return "WaveGetLaneIndex()";
    case BuiltInSubgroupSize:
        return "WaveGetLaneCount()";

    default:
        return CompilerGLSL::builtin_to_glsl(builtin, storage);
    }
}

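// Declares "static" globals for the builtins the shader actually uses, e.g. an
// active FragCoord is emitted as "static float4 gl_FragCoord;", optionally with
// an initializer recovered from the SPIR-V output variable.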
void CompilerHLSL::emit_builtin_variables()
{
    Bitset builtins = active_input_builtins;
    builtins.merge_or(active_output_builtins);

    bool need_base_vertex_info = false;

    std::unordered_map<uint32_t, ID> builtin_to_initializer;
    ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
        if (!is_builtin_variable(var) || var.storage != StorageClassOutput || !var.initializer)
            return;

        auto *c = this->maybe_get<SPIRConstant>(var.initializer);
        if (!c)
            return;

        auto &type = this->get<SPIRType>(var.basetype);
        if (type.basetype == SPIRType::Struct)
        {
            uint32_t member_count = uint32_t(type.member_types.size());
            for (uint32_t i = 0; i < member_count; i++)
            {
                if (has_member_decoration(type.self, i, DecorationBuiltIn))
                {
                    builtin_to_initializer[get_member_decoration(type.self, i, DecorationBuiltIn)] =
                        c->subconstants[i];
                }
            }
        }
        else if (has_decoration(var.self, DecorationBuiltIn))
            builtin_to_initializer[get_decoration(var.self, DecorationBuiltIn)] = var.initializer;
    });

    // Emit global variables for the interface variables which are statically used by the shader.
    builtins.for_each_bit([&](uint32_t i) {
        const char *type = nullptr;
        auto builtin = static_cast<BuiltIn>(i);
        uint32_t array_size = 0;

        string init_expr;
        auto init_itr = builtin_to_initializer.find(builtin);
        if (init_itr != builtin_to_initializer.end())
            init_expr = join(" = ", to_expression(init_itr->second));

        switch (builtin)
        {
        case BuiltInFragCoord:
        case BuiltInPosition:
            type = "float4";
            break;

        case BuiltInFragDepth:
            type = "float";
            break;

        case BuiltInVertexId:
        case BuiltInVertexIndex:
        case BuiltInInstanceIndex:
            type = "int";
            if (hlsl_options.support_nonzero_base_vertex_base_instance)
                need_base_vertex_info = true;
            break;

        case BuiltInInstanceId:
        case BuiltInSampleId:
            type = "int";
            break;

        case BuiltInPointSize:
            if (hlsl_options.point_size_compat)
            {
                // Just emit the global variable, it will be ignored.
                type = "float";
                break;
            }
            else
                SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));

        case BuiltInGlobalInvocationId:
        case BuiltInLocalInvocationId:
        case BuiltInWorkgroupId:
            type = "uint3";
            break;

        case BuiltInLocalInvocationIndex:
            type = "uint";
            break;

        case BuiltInFrontFacing:
            type = "bool";
            break;

        case BuiltInNumWorkgroups:
        case BuiltInPointCoord:
            // Handled specially.
            break;

        case BuiltInSubgroupLocalInvocationId:
        case BuiltInSubgroupSize:
            if (hlsl_options.shader_model < 60)
                SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
            break;

        case BuiltInSubgroupEqMask:
        case BuiltInSubgroupLtMask:
        case BuiltInSubgroupLeMask:
        case BuiltInSubgroupGtMask:
        case BuiltInSubgroupGeMask:
            if (hlsl_options.shader_model < 60)
                SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
            type = "uint4";
            break;

        case BuiltInClipDistance:
            array_size = clip_distance_count;
            type = "float";
            break;

        case BuiltInCullDistance:
            array_size = cull_distance_count;
            type = "float";
            break;

        case BuiltInSampleMask:
            type = "int";
            break;

        default:
            SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
        }

        StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;

        if (type)
        {
            if (array_size)
                statement("static ", type, " ", builtin_to_glsl(builtin, storage), "[", array_size, "]", init_expr, ";");
            else
                statement("static ", type, " ", builtin_to_glsl(builtin, storage), init_expr, ";");
        }

        // SampleMask can be both in and out with sample builtin, in this case we have already
        // declared the input variable and we need to add the output one now.
        if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i))
        {
            statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";");
        }
    });

    if (need_base_vertex_info)
    {
        statement("cbuffer SPIRV_Cross_VertexInfo");
        begin_scope();
        statement("int SPIRV_Cross_BaseVertex;");
        statement("int SPIRV_Cross_BaseInstance;");
        end_scope_decl();
        statement("");
    }
}

void CompilerHLSL::emit_composite_constants()
{
    // HLSL cannot declare structs or arrays inline, so we must move them out to
    // global constants directly.
    bool emitted = false;

    ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
        if (c.specialization)
            return;

        auto &type = this->get<SPIRType>(c.constant_type);

        if (type.basetype == SPIRType::Struct && is_builtin_type(type))
            return;

        if (type.basetype == SPIRType::Struct || !type.array.empty())
        {
            auto name = to_name(c.self);
            statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
            emitted = true;
        }
    });

    if (emitted)
        statement("");
}

void CompilerHLSL::emit_specialization_constants_and_structs()
{
    bool emitted = false;
    SpecializationConstant wg_x, wg_y, wg_z;
    ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);

    std::unordered_set<TypeID> io_block_types;
    ir.for_each_typed_id<SPIRVariable>([&](uint32_t, const SPIRVariable &var) {
        auto &type = this->get<SPIRType>(var.basetype);
        if ((var.storage == StorageClassInput || var.storage == StorageClassOutput) &&
            !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
            interface_variable_exists_in_entry_point(var.self) &&
            has_decoration(type.self, DecorationBlock))
        {
            io_block_types.insert(type.self);
        }
    });

    auto loop_lock = ir.create_loop_hard_lock();
    for (auto &id_ : ir.ids_for_constant_or_type)
    {
        auto &id = ir.ids[id_];

        if (id.get_type() == TypeConstant)
        {
            auto &c = id.get<SPIRConstant>();

            if (c.self == workgroup_size_id)
            {
                statement("static const uint3 gl_WorkGroupSize = ",
                          constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
                emitted = true;
            }
            else if (c.specialization)
            {
                auto &type = get<SPIRType>(c.constant_type);
                auto name = to_name(c.self);

                // HLSL does not support specialization constants, so fallback to macros.
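                // Illustratively, a constant with SpecId 10 is emitted roughly as:
                //   #ifndef SPIRV_CROSS_CONSTANT_ID_10
                //   #define SPIRV_CROSS_CONSTANT_ID_10 <default value>
                //   #endif
                //   static const int MyConstant = SPIRV_CROSS_CONSTANT_ID_10;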
                c.specialization_constant_macro_name =
                    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));

                statement("#ifndef ", c.specialization_constant_macro_name);
                statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
                statement("#endif");
                statement("static const ", variable_decl(type, name), " = ", c.specialization_constant_macro_name, ";");
                emitted = true;
            }
        }
        else if (id.get_type() == TypeConstantOp)
        {
            auto &c = id.get<SPIRConstantOp>();
            auto &type = get<SPIRType>(c.basetype);
            auto name = to_name(c.self);
            statement("static const ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
            emitted = true;
        }
        else if (id.get_type() == TypeType)
        {
            auto &type = id.get<SPIRType>();
            bool is_non_io_block = has_decoration(type.self, DecorationBlock) &&
                                   io_block_types.count(type.self) == 0;
            bool is_buffer_block = has_decoration(type.self, DecorationBufferBlock);
            if (type.basetype == SPIRType::Struct && type.array.empty() &&
                !type.pointer && !is_non_io_block && !is_buffer_block)
            {
                if (emitted)
                    statement("");
                emitted = false;

                emit_struct(type);
            }
        }
    }

    if (emitted)
        statement("");
}

void CompilerHLSL::replace_illegal_names()
{
    static const unordered_set<string> keywords = {
        // Additional HLSL specific keywords.
        "line", "linear", "matrix", "point", "row_major", "sampler", "vector"
    };

    CompilerGLSL::replace_illegal_names(keywords);
    CompilerGLSL::replace_illegal_names();
}

void CompilerHLSL::declare_undefined_values()
{
    bool emitted = false;
    ir.for_each_typed_id<SPIRUndef>([&](uint32_t, const SPIRUndef &undef) {
        auto &type = this->get<SPIRType>(undef.basetype);
        // OpUndef can be void for some reason ...
        if (type.basetype == SPIRType::Void)
            return;

        string initializer;
        if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
            initializer = join(" = ", to_zero_initialized_expression(undef.basetype));

        statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
        emitted = true;
    });

    if (emitted)
        statement("");
}

void CompilerHLSL::emit_resources()
{
    auto &execution = get_entry_point();

    replace_illegal_names();

    emit_specialization_constants_and_structs();
    emit_composite_constants();

    bool emitted = false;

    // Output UBOs and SSBOs
    ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
        auto &type = this->get<SPIRType>(var.basetype);

        bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
        bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) ||
                               ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);

        if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
            has_block_flags)
        {
            emit_buffer_block(var);
            emitted = true;
        }
    });

    // Output push constant blocks
    ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
        auto &type = this->get<SPIRType>(var.basetype);
        if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
            !is_hidden_variable(var))
        {
            emit_push_constant_block(var);
            emitted = true;
        }
    });

    if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30)
    {
        statement("uniform float4 gl_HalfPixel;");
        emitted = true;
    }

    bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;

    // Output Uniform Constants (values, samplers, images, etc).
    ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
        auto &type = this->get<SPIRType>(var.basetype);

        // If we're remapping separate samplers and images, only emit the combined samplers.
        if (skip_separate_image_sampler)
        {
            // Sampler buffers are always used without a sampler, and they will also work in regular D3D.
            bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
            bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
            bool separate_sampler = type.basetype == SPIRType::Sampler;
            if (!sampler_buffer && (separate_image || separate_sampler))
                return;
        }

        if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
            type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter) &&
            !is_hidden_variable(var))
        {
            emit_uniform(var);
            emitted = true;
        }
    });

    if (emitted)
        statement("");
    emitted = false;

    // Emit builtin input and output variables here.
    emit_builtin_variables();

    ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
        auto &type = this->get<SPIRType>(var.basetype);

        if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
            (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
            interface_variable_exists_in_entry_point(var.self))
        {
            // Builtin variables are handled separately.
            emit_interface_block_globally(var);
            emitted = true;
        }
    });

    if (emitted)
        statement("");
    emitted = false;

    require_input = false;
    require_output = false;
    unordered_set<uint32_t> active_inputs;
    unordered_set<uint32_t> active_outputs;

    struct IOVariable
    {
        const SPIRVariable *var;
        uint32_t location;
        uint32_t block_member_index;
        bool block;
    };

    SmallVector<IOVariable> input_variables;
    SmallVector<IOVariable> output_variables;

    ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
        auto &type = this->get<SPIRType>(var.basetype);
        bool block = has_decoration(type.self, DecorationBlock);

        if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
            return;

        if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
            interface_variable_exists_in_entry_point(var.self))
        {
            if (block)
            {
                for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
                {
                    uint32_t location = get_declared_member_location(var, i, false);
                    if (var.storage == StorageClassInput)
                        input_variables.push_back({ &var, location, i, true });
                    else
                        output_variables.push_back({ &var, location, i, true });
                }
            }
            else
            {
                uint32_t location = get_decoration(var.self, DecorationLocation);
                if (var.storage == StorageClassInput)
                    input_variables.push_back({ &var, location, 0, false });
                else
                    output_variables.push_back({ &var, location, 0, false });
            }
        }
    });

    const auto variable_compare = [&](const IOVariable &a, const IOVariable &b) -> bool {
        // Sort input and output variables based on, from more robust to less robust:
        // - Location
        // - Variable has a location
        // - Name comparison
        // - Variable has a name
        // - Fallback: ID
        bool has_location_a = a.block || has_decoration(a.var->self, DecorationLocation);
        bool has_location_b = b.block || has_decoration(b.var->self, DecorationLocation);

        if (has_location_a && has_location_b)
            return a.location < b.location;
        else if (has_location_a && !has_location_b)
            return true;
        else if (!has_location_a && has_location_b)
            return false;

        const auto &name1 = to_name(a.var->self);
        const auto &name2 = to_name(b.var->self);

        if (name1.empty() && name2.empty())
            return a.var->self < b.var->self;
        else if (name1.empty())
            return true;
        else if (name2.empty())
            return false;

        return name1.compare(name2) < 0;
    };

    auto input_builtins = active_input_builtins;
    input_builtins.clear(BuiltInNumWorkgroups);
    input_builtins.clear(BuiltInPointCoord);
    input_builtins.clear(BuiltInSubgroupSize);
    input_builtins.clear(BuiltInSubgroupLocalInvocationId);
    input_builtins.clear(BuiltInSubgroupEqMask);
    input_builtins.clear(BuiltInSubgroupLtMask);
    input_builtins.clear(BuiltInSubgroupLeMask);
    input_builtins.clear(BuiltInSubgroupGtMask);
    input_builtins.clear(BuiltInSubgroupGeMask);

    if (!input_variables.empty() || !input_builtins.empty())
    {
        require_input = true;
        statement("struct SPIRV_Cross_Input");

        begin_scope();
        sort(input_variables.begin(), input_variables.end(), variable_compare);
        for (auto &var : input_variables)
        {
            if (var.block)
                emit_interface_block_member_in_struct(*var.var, var.block_member_index, var.location, active_inputs);
            else
                emit_interface_block_in_struct(*var.var, active_inputs);
        }
        emit_builtin_inputs_in_struct();
        end_scope_decl();
        statement("");
    }

    if (!output_variables.empty() || !active_output_builtins.empty())
    {
        require_output = true;
        statement("struct SPIRV_Cross_Output");

        begin_scope();
        sort(output_variables.begin(), output_variables.end(), variable_compare);
        for (auto &var : output_variables)
        {
            if (var.block)
                emit_interface_block_member_in_struct(*var.var, var.block_member_index, var.location, active_outputs);
            else
                emit_interface_block_in_struct(*var.var, active_outputs);
        }
        emit_builtin_outputs_in_struct();
        end_scope_decl();
        statement("");
    }

    // Global variables.
    for (auto global : global_variables)
    {
        auto &var = get<SPIRVariable>(global);
        if (is_hidden_variable(var, true))
            continue;

        if (var.storage != StorageClassOutput)
        {
            if (!variable_is_lut(var))
            {
                add_resource_name(var.self);

                const char *storage = nullptr;
                switch (var.storage)
                {
                case StorageClassWorkgroup:
                    storage = "groupshared";
                    break;

                default:
                    storage = "static";
                    break;
                }

                string initializer;
                if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
                    !var.initializer && !var.static_expression && type_can_zero_initialize(get_variable_data_type(var)))
                {
                    initializer = join(" = ", to_zero_initialized_expression(get_variable_data_type_id(var)));
                }
                statement(storage, " ", variable_decl(var), initializer, ";");

                emitted = true;
            }
        }
    }

    if (emitted)
        statement("");

    declare_undefined_values();

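    // GLSL mod() is a floored remainder (x - y * floor(x / y)), while HLSL's
    // fmod() truncates toward zero, so a matching helper is emitted instead:
    // mod(-1.0, 3.0) is 2.0, but fmod(-1.0, 3.0) is -1.0.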
    if (requires_op_fmod)
    {
        static const char *types[] = {
            "float",
            "float2",
            "float3",
            "float4",
        };

        for (auto &type : types)
        {
            statement(type, " mod(", type, " x, ", type, " y)");
            begin_scope();
            statement("return x - y * floor(x / y);");
            end_scope();
            statement("");
        }
    }

    emit_texture_size_variants(required_texture_size_variants.srv, "4", false, "");
    for (uint32_t norm = 0; norm < 3; norm++)
    {
        for (uint32_t comp = 0; comp < 4; comp++)
        {
            static const char *qualifiers[] = { "", "unorm ", "snorm " };
            static const char *vecsizes[] = { "", "2", "3", "4" };
            emit_texture_size_variants(required_texture_size_variants.uav[norm][comp], vecsizes[comp], true,
                                       qualifiers[norm]);
        }
    }

    if (requires_fp16_packing)
    {
        // HLSL does not pack into a single word sadly :(
        statement("uint spvPackHalf2x16(float2 value)");
        begin_scope();
        statement("uint2 Packed = f32tof16(value);");
        statement("return Packed.x | (Packed.y << 16);");
        end_scope();
        statement("");

        statement("float2 spvUnpackHalf2x16(uint value)");
        begin_scope();
        statement("return f16tof32(uint2(value & 0xffff, value >> 16));");
        end_scope();
        statement("");
    }

    if (requires_uint2_packing)
    {
        statement("uint64_t spvPackUint2x32(uint2 value)");
        begin_scope();
        statement("return (uint64_t(value.y) << 32) | uint64_t(value.x);");
        end_scope();
        statement("");

        statement("uint2 spvUnpackUint2x32(uint64_t value)");
        begin_scope();
        statement("uint2 Unpacked;");
        statement("Unpacked.x = uint(value & 0xffffffff);");
        statement("Unpacked.y = uint(value >> 32);");
        statement("return Unpacked;");
        end_scope();
        statement("");
    }

    if (requires_explicit_fp16_packing)
    {
        // HLSL does not pack into a single word sadly :(
        statement("uint spvPackFloat2x16(min16float2 value)");
        begin_scope();
        statement("uint2 Packed = f32tof16(value);");
        statement("return Packed.x | (Packed.y << 16);");
        end_scope();
        statement("");

        statement("min16float2 spvUnpackFloat2x16(uint value)");
        begin_scope();
        statement("return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
        end_scope();
        statement("");
    }

    // HLSL does not seem to have builtins for these operations, so roll them by hand ...
    if (requires_unorm8_packing)
    {
        statement("uint spvPackUnorm4x8(float4 value)");
        begin_scope();
        statement("uint4 Packed = uint4(round(saturate(value) * 255.0));");
        statement("return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
        end_scope();
        statement("");

        statement("float4 spvUnpackUnorm4x8(uint value)");
        begin_scope();
        statement("uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
        statement("return float4(Packed) / 255.0;");
        end_scope();
        statement("");
    }

    if (requires_snorm8_packing)
    {
        statement("uint spvPackSnorm4x8(float4 value)");
        begin_scope();
        statement("int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
        statement("return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
        end_scope();
        statement("");

        statement("float4 spvUnpackSnorm4x8(uint value)");
        begin_scope();
        statement("int SignedValue = int(value);");
        statement("int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
        statement("return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
        end_scope();
        statement("");
    }

    if (requires_unorm16_packing)
    {
        statement("uint spvPackUnorm2x16(float2 value)");
        begin_scope();
        statement("uint2 Packed = uint2(round(saturate(value) * 65535.0));");
        statement("return Packed.x | (Packed.y << 16);");
        end_scope();
        statement("");

        statement("float2 spvUnpackUnorm2x16(uint value)");
        begin_scope();
        statement("uint2 Packed = uint2(value & 0xffff, value >> 16);");
        statement("return float2(Packed) / 65535.0;");
        end_scope();
        statement("");
    }

    if (requires_snorm16_packing)
    {
        statement("uint spvPackSnorm2x16(float2 value)");
        begin_scope();
        statement("int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
        statement("return uint(Packed.x | (Packed.y << 16));");
        end_scope();
        statement("");

        statement("float2 spvUnpackSnorm2x16(uint value)");
        begin_scope();
        statement("int SignedValue = int(value);");
        statement("int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
        statement("return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
        end_scope();
        statement("");
    }

    if (requires_bitfield_insert)
    {
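        // Emulates SPIR-V OpBitFieldInsert, e.g. spvBitfieldInsert(0u, 0xFu, 4u, 4u)
        // yields 0xF0u.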
        static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
        for (auto &type : types)
        {
            statement(type, " spvBitfieldInsert(", type, " Base, ", type, " Insert, uint Offset, uint Count)");
            begin_scope();
            statement("uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
            statement("return (Base & ~Mask) | ((Insert << Offset) & Mask);");
            end_scope();
            statement("");
        }
    }

    if (requires_bitfield_extract)
    {
        static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
        for (auto &type : unsigned_types)
        {
            statement(type, " spvBitfieldUExtract(", type, " Base, uint Offset, uint Count)");
            begin_scope();
            statement("uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
            statement("return (Base >> Offset) & Mask;");
            end_scope();
            statement("");
        }

        // In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
        static const char *signed_types[] = { "int", "int2", "int3", "int4" };
        for (auto &type : signed_types)
        {
            statement(type, " spvBitfieldSExtract(", type, " Base, int Offset, int Count)");
            begin_scope();
            statement("int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
            statement(type, " Masked = (Base >> Offset) & Mask;");
            statement("int ExtendShift = (32 - Count) & 31;");
            statement("return (Masked << ExtendShift) >> ExtendShift;");
            end_scope();
            statement("");
        }
    }
1750
1751 if (requires_inverse_2x2)
1752 {
1753 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1754 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1755 statement("float2x2 spvInverse(float2x2 m)");
1756 begin_scope();
1757 statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
1758 statement_no_indent("");
1759 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1760 statement("adj[0][0] = m[1][1];");
1761 statement("adj[0][1] = -m[0][1];");
1762 statement_no_indent("");
1763 statement("adj[1][0] = -m[1][0];");
1764 statement("adj[1][1] = m[0][0];");
1765 statement_no_indent("");
1766 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1767 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
1768 statement_no_indent("");
1769 statement("// Divide the classical adjoint matrix by the determinant.");
1770 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1771 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1772 end_scope();
1773 statement("");
1774 }
1775
1776 if (requires_inverse_3x3)
1777 {
1778 statement("// Returns the determinant of a 2x2 matrix.");
1779 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1780 begin_scope();
1781 statement("return a1 * b2 - b1 * a2;");
1782 end_scope();
1783 statement_no_indent("");
1784 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1785 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1786 statement("float3x3 spvInverse(float3x3 m)");
1787 begin_scope();
1788 statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
1789 statement_no_indent("");
1790 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1791 statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
1792 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
1793 statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
1794 statement_no_indent("");
1795 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
1796 statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
1797 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
1798 statement_no_indent("");
1799 statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
1800 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
1801 statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
1802 statement_no_indent("");
1803 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1804 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
1805 statement_no_indent("");
1806 statement("// Divide the classical adjoint matrix by the determinant.");
1807 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1808 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1809 end_scope();
1810 statement("");
1811 }
1812
1813 if (requires_inverse_4x4)
1814 {
1815 if (!requires_inverse_3x3)
1816 {
1817 statement("// Returns the determinant of a 2x2 matrix.");
1818 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1819 begin_scope();
1820 statement("return a1 * b2 - b1 * a2;");
1821 end_scope();
1822 statement("");
1823 }
1824
1825 statement("// Returns the determinant of a 3x3 matrix.");
1826 statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
1827 "float c2, float c3)");
1828 begin_scope();
1829 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
1830 "spvDet2x2(a2, a3, "
1831 "b2, b3);");
1832 end_scope();
1833 statement_no_indent("");
1834 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1835 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1836 statement("float4x4 spvInverse(float4x4 m)");
1837 begin_scope();
1838 statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
1839 statement_no_indent("");
1840 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1841 statement(
1842 "adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1843 "m[3][3]);");
1844 statement(
1845 "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1846 "m[3][3]);");
1847 statement(
1848 "adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
1849 "m[3][3]);");
1850 statement(
1851 "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
1852 "m[2][3]);");
1853 statement_no_indent("");
1854 statement(
1855 "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1856 "m[3][3]);");
1857 statement(
1858 "adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1859 "m[3][3]);");
1860 statement(
1861 "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
1862 "m[3][3]);");
1863 statement(
1864 "adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
1865 "m[2][3]);");
1866 statement_no_indent("");
1867 statement(
1868 "adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1869 "m[3][3]);");
1870 statement(
1871 "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1872 "m[3][3]);");
1873 statement(
1874 "adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
1875 "m[3][3]);");
1876 statement(
1877 "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
1878 "m[2][3]);");
1879 statement_no_indent("");
1880 statement(
1881 "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1882 "m[3][2]);");
1883 statement(
1884 "adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1885 "m[3][2]);");
1886 statement(
1887 "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
1888 "m[3][2]);");
1889 statement(
1890 "adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
1891 "m[2][2]);");
1892 statement_no_indent("");
1893 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1894 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
1895 "* m[3][0]);");
1896 statement_no_indent("");
1897 statement("// Divide the classical adjoint matrix by the determinant.");
1898 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1899 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1900 end_scope();
1901 statement("");
1902 }
1903
1904 if (requires_scalar_reflect)
1905 {
1906 // FP16/FP64? No templates in HLSL.
1907 statement("float spvReflect(float i, float n)");
1908 begin_scope();
1909 statement("return i - 2.0 * dot(n, i) * n;");
1910 end_scope();
1911 statement("");
1912 }
1913
1914 if (requires_scalar_refract)
1915 {
1916 // FP16/FP64? No templates in HLSL.
1917 statement("float spvRefract(float i, float n, float eta)");
1918 begin_scope();
1919 statement("float NoI = n * i;");
1920 statement("float NoI2 = NoI * NoI;");
1921 statement("float k = 1.0 - eta * eta * (1.0 - NoI2);");
1922 statement("if (k < 0.0)");
1923 begin_scope();
1924 statement("return 0.0;");
1925 end_scope();
1926 statement("else");
1927 begin_scope();
1928 statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
1929 end_scope();
1930 end_scope();
1931 statement("");
1932 }
1933
1934 if (requires_scalar_faceforward)
1935 {
1936 // FP16/FP64? No templates in HLSL.
1937 statement("float spvFaceForward(float n, float i, float nref)");
1938 begin_scope();
1939 statement("return i * nref < 0.0 ? n : -n;");
1940 end_scope();
1941 statement("");
1942 }
1943
1944 for (TypeID type_id : composite_selection_workaround_types)
1945 {
1946 // Need out variable since HLSL does not support returning arrays.
1947 auto &type = get<SPIRType>(type_id);
1948 auto type_str = type_to_glsl(type);
1949 auto type_arr_str = type_to_array_glsl(type);
1950 statement("void spvSelectComposite(out ", type_str, " out_value", type_arr_str, ", bool cond, ",
1951 type_str, " true_val", type_arr_str, ", ",
1952 type_str, " false_val", type_arr_str, ")");
1953 begin_scope();
1954 statement("if (cond)");
1955 begin_scope();
1956 statement("out_value = true_val;");
1957 end_scope();
1958 statement("else");
1959 begin_scope();
1960 statement("out_value = false_val;");
1961 end_scope();
1962 end_scope();
1963 statement("");
1964 }
1965 }
1966
1967 void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
1968 const char *type_qualifier)
1969 {
1970 if (variant_mask == 0)
1971 return;
1972
1973 static const char *types[QueryTypeCount] = { "float", "int", "uint" };
1974 static const char *dims[QueryDimCount] = { "Texture1D", "Texture1DArray", "Texture2D", "Texture2DArray",
1975 "Texture3D", "Buffer", "TextureCube", "TextureCubeArray",
1976 "Texture2DMS", "Texture2DMSArray" };
1977
1978 static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
1979
1980 static const char *ret_types[QueryDimCount] = {
1981 "uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
1982 };
1983
1984 static const uint32_t return_arguments[QueryDimCount] = {
1985 1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
1986 };
1987
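// variant_mask packs one bit per (query type, dimension) pair: each of the float/int/uint
// query types owns a 16-bit lane, and the dimension indexes into that lane
// (bit = 16 * type_index + index below). Only the requested variants are emitted.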
1988 for (uint32_t index = 0; index < QueryDimCount; index++)
1989 {
1990 for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
1991 {
1992 uint32_t bit = 16 * type_index + index;
1993 uint64_t mask = 1ull << bit;
1994
1995 if ((variant_mask & mask) == 0)
1996 continue;
1997
1998 statement(ret_types[index], " spv", (uav ? "Image" : "Texture"), "Size(", (uav ? "RW" : ""),
1999 dims[index], "<", type_qualifier, types[type_index], vecsize_qualifier, "> Tex, ",
2000 (uav ? "" : "uint Level, "), "out uint Param)");
2001 begin_scope();
2002 statement(ret_types[index], " ret;");
2003 switch (return_arguments[index])
2004 {
2005 case 1:
2006 if (has_lod[index] && !uav)
2007 statement("Tex.GetDimensions(Level, ret.x, Param);");
2008 else
2009 {
2010 statement("Tex.GetDimensions(ret.x);");
2011 statement("Param = 0u;");
2012 }
2013 break;
2014 case 2:
2015 if (has_lod[index] && !uav)
2016 statement("Tex.GetDimensions(Level, ret.x, ret.y, Param);");
2017 else if (!uav)
2018 statement("Tex.GetDimensions(ret.x, ret.y, Param);");
2019 else
2020 {
2021 statement("Tex.GetDimensions(ret.x, ret.y);");
2022 statement("Param = 0u;");
2023 }
2024 break;
2025 case 3:
2026 if (has_lod[index] && !uav)
2027 statement("Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
2028 else if (!uav)
2029 statement("Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
2030 else
2031 {
2032 statement("Tex.GetDimensions(ret.x, ret.y, ret.z);");
2033 statement("Param = 0u;");
2034 }
2035 break;
2036 }
2037
2038 statement("return ret;");
2039 end_scope();
2040 statement("");
2041 }
2042 }
2043 }
2044
2045 string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2046 {
2047 auto &flags = get_member_decoration_bitset(type.self, index);
2048
2049 // HLSL can emit row_major or column_major decoration in any struct.
2050 // Do not try to merge combined decorations for children like in GLSL.
2051
2052 // Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2053 // The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
2054 if (flags.get(DecorationColMajor))
2055 return "row_major ";
2056 else if (flags.get(DecorationRowMajor))
2057 return "column_major ";
2058
2059 return "";
2060 }
2061
2062 void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2063 const string &qualifier, uint32_t base_offset)
2064 {
2065 auto &membertype = get<SPIRType>(member_type_id);
2066
2067 Bitset memberflags;
2068 auto &memb = ir.meta[type.self].members;
2069 if (index < memb.size())
2070 memberflags = memb[index].decoration_flags;
2071
2072 string packing_offset;
2073 bool is_push_constant = type.storage == StorageClassPushConstant;
2074
2075 if ((has_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2076 has_member_decoration(type.self, index, DecorationOffset))
2077 {
2078 uint32_t offset = memb[index].offset - base_offset;
2079 if (offset & 3)
2080 SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2081
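// packoffset() addresses 16-byte registers, with .xyzw selecting the 4-byte lane within a
// register; e.g. a member at byte offset 20 is emitted as ": packoffset(c1.y)".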
2082 static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
2083 packing_offset = join(" : packoffset(c", offset / 16, packing_swizzle[(offset & 15) >> 2], ")");
2084 }
2085
2086 statement(layout_for_member(type, index), qualifier,
2087 variable_decl(membertype, to_member_name(type, index)), packing_offset, ";");
2088 }
2089
2090 void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2091 {
2092 auto &type = get<SPIRType>(var.basetype);
2093
2094 bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock);
2095
2096 if (is_uav)
2097 {
2098 Bitset flags = ir.get_buffer_block_flags(var);
2099 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
2100 bool is_coherent = flags.get(DecorationCoherent) && !is_readonly;
2101 bool is_interlocked = interlocked_resources.count(var.self) > 0;
2102 const char *type_name = "ByteAddressBuffer ";
2103 if (!is_readonly)
2104 type_name = is_interlocked ? "RasterizerOrderedByteAddressBuffer " : "RWByteAddressBuffer ";
2105 add_resource_name(var.self);
2106 statement(is_coherent ? "globallycoherent " : "", type_name, to_name(var.self), type_to_array_glsl(type),
2107 to_resource_binding(var), ";");
2108 }
2109 else
2110 {
2111 if (type.array.empty())
2112 {
2113 // Flatten the top-level struct so we can use packoffset;
2114 // this restriction is similar to GLSL, where layout(offset) is not possible on sub-structs.
2115 flattened_structs[var.self] = false;
2116
2117 // Prefer the block name if possible.
2118 auto buffer_name = to_name(type.self, false);
2119 if (ir.meta[type.self].decoration.alias.empty() ||
2120 resource_names.find(buffer_name) != end(resource_names) ||
2121 block_names.find(buffer_name) != end(block_names))
2122 {
2123 buffer_name = get_block_fallback_name(var.self);
2124 }
2125
2126 add_variable(block_names, resource_names, buffer_name);
2127
2128 // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2129 // This cannot conflict with anything else, so we're safe now.
2130 if (buffer_name.empty())
2131 buffer_name = join("_", get<SPIRType>(var.basetype).self, "_", var.self);
2132
2133 uint32_t failed_index = 0;
2134 if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index))
2135 set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2136 else
2137 {
2138 SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2139 failed_index, " (name: ", to_member_name(type, failed_index),
2140 ") cannot be expressed with either HLSL packing layout or packoffset."));
2141 }
2142
2143 block_names.insert(buffer_name);
2144
2145 // Save for post-reflection later.
2146 declared_block_names[var.self] = buffer_name;
2147
2148 type.member_name_cache.clear();
2149 // var.self can be used as a backup name for the block name,
2150 // so we need to make sure we don't disturb the name here on a recompile.
2151 // It will need to be reset if we have to recompile.
2152 preserve_alias_on_reset(var.self);
2153 add_resource_name(var.self);
2154 statement("cbuffer ", buffer_name, to_resource_binding(var));
2155 begin_scope();
2156
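// Temporarily rename each member to a flattened, globally unique name while emitting it,
// then restore the original name so the IR is left undisturbed for any recompile pass.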
2157 uint32_t i = 0;
2158 for (auto &member : type.member_types)
2159 {
2160 add_member_name(type, i);
2161 auto backup_name = get_member_name(type.self, i);
2162 auto member_name = to_member_name(type, i);
2163 member_name = join(to_name(var.self), "_", member_name);
2164 ParsedIR::sanitize_underscores(member_name);
2165 set_member_name(type.self, i, member_name);
2166 emit_struct_member(type, member, i, "");
2167 set_member_name(type.self, i, backup_name);
2168 i++;
2169 }
2170
2171 end_scope_decl();
2172 statement("");
2173 }
2174 else
2175 {
2176 if (hlsl_options.shader_model < 51)
2177 SPIRV_CROSS_THROW(
2178 "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2179
2180 add_resource_name(type.self);
2181 add_resource_name(var.self);
2182
2183 // ConstantBuffer<T> does not support packoffset, so it is unusable unless everything aligns as we expect.
2184 uint32_t failed_index = 0;
2185 if (!buffer_is_packing_standard(type, BufferPackingHLSLCbuffer, &failed_index))
2186 {
2187 SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2188 "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2189 ") cannot be expressed with normal HLSL packing rules."));
2190 }
2191
2192 emit_struct(get<SPIRType>(type.self));
2193 statement("ConstantBuffer<", to_name(type.self), "> ", to_name(var.self), type_to_array_glsl(type),
2194 to_resource_binding(var), ";");
2195 }
2196 }
2197 }
2198
2199 void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2200 {
2201 if (root_constants_layout.empty())
2202 {
2203 emit_buffer_block(var);
2204 }
2205 else
2206 {
2207 for (const auto &layout : root_constants_layout)
2208 {
2209 auto &type = get<SPIRType>(var.basetype);
2210
2211 uint32_t failed_index = 0;
2212 if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index, layout.start,
2213 layout.end))
2214 set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2215 else
2216 {
2217 SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2218 ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2219 ") cannot be expressed with either HLSL packing layout or packoffset."));
2220 }
2221
2222 flattened_structs[var.self] = false;
2223 type.member_name_cache.clear();
2224 add_resource_name(var.self);
2225 auto &memb = ir.meta[type.self].members;
2226
2227 statement("cbuffer SPIRV_CROSS_RootConstant_", to_name(var.self),
2228 to_resource_register(HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, 'b', layout.binding, layout.space));
2229 begin_scope();
2230
2231 // Index of the next field in the generated root-constant cbuffer.
2232 auto constant_index = 0u;
2233
2234 // Iterate over all members of the push constant block and check which fields
2235 // fit into the given root constant layout.
2236 for (auto i = 0u; i < memb.size(); i++)
2237 {
2238 const auto offset = memb[i].offset;
2239 if (layout.start <= offset && offset < layout.end)
2240 {
2241 const auto &member = type.member_types[i];
2242
2243 add_member_name(type, constant_index);
2244 auto backup_name = get_member_name(type.self, i);
2245 auto member_name = to_member_name(type, i);
2246 member_name = join(to_name(var.self), "_", member_name);
2247 ParsedIR::sanitize_underscores(member_name);
2248 set_member_name(type.self, constant_index, member_name);
2249 emit_struct_member(type, member, i, "", layout.start);
2250 set_member_name(type.self, constant_index, backup_name);
2251
2252 constant_index++;
2253 }
2254 }
2255
2256 end_scope_decl();
2257 }
2258 }
2259 }
2260
2261 string CompilerHLSL::to_sampler_expression(uint32_t id)
2262 {
2263 auto expr = join("_", to_non_uniform_aware_expression(id));
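// E.g. "tex" maps to "_tex_sampler", while an arrayed expression like "tex[3]" must map
// to "_tex_sampler[3]" rather than "_tex[3]_sampler".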
2264 auto index = expr.find_first_of('[');
2265 if (index == string::npos)
2266 {
2267 return expr + "_sampler";
2268 }
2269 else
2270 {
2271 // We have an expression like _ident[array], so we cannot tack on _sampler, insert it inside the string instead.
2272 return expr.insert(index, "_sampler");
2273 }
2274 }
2275
2276 void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2277 {
2278 if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2279 {
2280 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
2281 }
2282 else
2283 {
2284 // Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2285 emit_op(result_type, result_id, to_combined_image_sampler(image_id, samp_id), true, true);
2286 }
2287 }
2288
2289 string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2290 {
2291 string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2292
2293 if (hlsl_options.shader_model <= 30)
2294 return arg_str;
2295
2296 // Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2297 auto &type = expression_type(id);
2298
2299 // We don't have to consider combined image samplers here via OpSampledImage because
2300 // those variables cannot be passed as arguments to functions.
2301 // Only global SampledImage variables may be used as arguments.
2302 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2303 arg_str += ", " + to_sampler_expression(id);
2304
2305 return arg_str;
2306 }
2307
2308 void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2309 {
2310 if (func.self != ir.default_entry_point)
2311 add_function_overload(func);
2312
2313 auto &execution = get_entry_point();
2314 // Avoid shadow declarations.
2315 local_variable_names = resource_names;
2316
2317 string decl;
2318
2319 auto &type = get<SPIRType>(func.return_type);
2320 if (type.array.empty())
2321 {
2322 decl += flags_to_qualifiers_glsl(type, return_flags);
2323 decl += type_to_glsl(type);
2324 decl += " ";
2325 }
2326 else
2327 {
2328 // We cannot return arrays in HLSL, so "return" through an out variable.
2329 decl = "void ";
2330 }
2331
2332 if (func.self == ir.default_entry_point)
2333 {
2334 if (execution.model == ExecutionModelVertex)
2335 decl += "vert_main";
2336 else if (execution.model == ExecutionModelFragment)
2337 decl += "frag_main";
2338 else if (execution.model == ExecutionModelGLCompute)
2339 decl += "comp_main";
2340 else
2341 SPIRV_CROSS_THROW("Unsupported execution model.");
2342 processing_entry_point = true;
2343 }
2344 else
2345 decl += to_name(func.self);
2346
2347 decl += "(";
2348 SmallVector<string> arglist;
2349
2350 if (!type.array.empty())
2351 {
2352 // Fake array returns by writing to an out array instead.
2353 string out_argument;
2354 out_argument += "out ";
2355 out_argument += type_to_glsl(type);
2356 out_argument += " ";
2357 out_argument += "spvReturnValue";
2358 out_argument += type_to_array_glsl(type);
2359 arglist.push_back(move(out_argument));
2360 }
2361
2362 for (auto &arg : func.arguments)
2363 {
2364 // Do not pass in separate images or samplers if we're remapping
2365 // to combined image samplers.
2366 if (skip_argument(arg.id))
2367 continue;
2368
2369 // Might change the variable name if it already exists in this function.
2370 // SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
2371 // to use the same name for multiple variables.
2372 // Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates.
2373 add_local_variable_name(arg.id);
2374
2375 arglist.push_back(argument_decl(arg));
2376
2377 // Flatten a combined sampler to two separate arguments in modern HLSL.
2378 auto &arg_type = get<SPIRType>(arg.type);
2379 if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
2380 arg_type.image.dim != DimBuffer)
2381 {
2382 // Manufacture automatic sampler arg for SampledImage texture
2383 arglist.push_back(join(image_is_comparison(arg_type, arg.id) ? "SamplerComparisonState " : "SamplerState ",
2384 to_sampler_expression(arg.id), type_to_array_glsl(arg_type)));
2385 }
2386
2387 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2388 auto *var = maybe_get<SPIRVariable>(arg.id);
2389 if (var)
2390 var->parameter = &arg;
2391 }
2392
2393 for (auto &arg : func.shadow_arguments)
2394 {
2395 // Might change the variable name if it already exists in this function.
2396 // SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
2397 // to use the same name for multiple variables.
2398 // Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates.
2399 add_local_variable_name(arg.id);
2400
2401 arglist.push_back(argument_decl(arg));
2402
2403 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2404 auto *var = maybe_get<SPIRVariable>(arg.id);
2405 if (var)
2406 var->parameter = &arg;
2407 }
2408
2409 decl += merge(arglist);
2410 decl += ")";
2411 statement(decl);
2412 }
2413
2414 void CompilerHLSL::emit_hlsl_entry_point()
2415 {
2416 SmallVector<string> arguments;
2417
2418 if (require_input)
2419 arguments.push_back("SPIRV_Cross_Input stage_input");
2420
2421 auto &execution = get_entry_point();
2422
2423 switch (execution.model)
2424 {
2425 case ExecutionModelGLCompute:
2426 {
2427 SpecializationConstant wg_x, wg_y, wg_z;
2428 get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
2429
2430 uint32_t x = execution.workgroup_size.x;
2431 uint32_t y = execution.workgroup_size.y;
2432 uint32_t z = execution.workgroup_size.z;
2433
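// Specialization constants are emitted as preprocessor macros in the generated HLSL, so
// reference the macro name rather than a literal when a dimension is specializable.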
2434 auto x_expr = wg_x.id ? get<SPIRConstant>(wg_x.id).specialization_constant_macro_name : to_string(x);
2435 auto y_expr = wg_y.id ? get<SPIRConstant>(wg_y.id).specialization_constant_macro_name : to_string(y);
2436 auto z_expr = wg_z.id ? get<SPIRConstant>(wg_z.id).specialization_constant_macro_name : to_string(z);
2437
2438 statement("[numthreads(", x_expr, ", ", y_expr, ", ", z_expr, ")]");
2439 break;
2440 }
2441 case ExecutionModelFragment:
2442 if (execution.flags.get(ExecutionModeEarlyFragmentTests))
2443 statement("[earlydepthstencil]");
2444 break;
2445 default:
2446 break;
2447 }
2448
2449 statement(require_output ? "SPIRV_Cross_Output " : "void ", "main(", merge(arguments), ")");
2450 begin_scope();
2451 bool legacy = hlsl_options.shader_model <= 30;
2452
2453 // Copy builtins from entry point arguments to globals.
2454 active_input_builtins.for_each_bit([&](uint32_t i) {
2455 auto builtin = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassInput);
2456 switch (static_cast<BuiltIn>(i))
2457 {
2458 case BuiltInFragCoord:
2459 // VPOS in D3D9 is sampled at integer locations; apply a half-pixel offset to be consistent.
2460 // TODO: Do we need an option here? Any reason why a D3D9 shader would be used
2461 // on a D3D10+ system with a different rasterization config?
2462 if (legacy)
2463 statement(builtin, " = stage_input.", builtin, " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
2464 else
2465 {
2466 statement(builtin, " = stage_input.", builtin, ";");
2467 // ZW are undefined in D3D9, so only do this fixup in the non-legacy path.
2468 statement(builtin, ".w = 1.0 / ", builtin, ".w;");
2469 }
2470 break;
2471
2472 case BuiltInVertexId:
2473 case BuiltInVertexIndex:
2474 case BuiltInInstanceIndex:
2475 // D3D semantics are uint, but shader wants int.
2476 if (hlsl_options.support_nonzero_base_vertex_base_instance)
2477 {
2478 if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
2479 statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
2480 else
2481 statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
2482 }
2483 else
2484 statement(builtin, " = int(stage_input.", builtin, ");");
2485 break;
2486
2487 case BuiltInInstanceId:
2488 // D3D semantics are uint, but shader wants int.
2489 statement(builtin, " = int(stage_input.", builtin, ");");
2490 break;
2491
2492 case BuiltInNumWorkgroups:
2493 case BuiltInPointCoord:
2494 case BuiltInSubgroupSize:
2495 case BuiltInSubgroupLocalInvocationId:
2496 break;
2497
2498 case BuiltInSubgroupEqMask:
2499 // Emulate these ...
2500 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
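// E.g. lane 35: the per-component shifts set bit 3 in several components (HLSL masks
// shift amounts to 5 bits), and the if-chain below zeroes every component whose 32-lane
// window does not contain the lane, leaving uint4(0, 1u << 3, 0, 0).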
2501 statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
2502 statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
2503 statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
2504 statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
2505 statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
2506 break;
2507
2508 case BuiltInSubgroupGeMask:
2509 // Emulate these ...
2510 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2511 statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
2512 statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
2513 statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
2514 statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
2515 statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
2516 statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
2517 statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
2518 break;
2519
2520 case BuiltInSubgroupGtMask:
2521 // Emulate these ...
2522 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
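// "Greater than lane" is computed as "greater than or equal to lane + 1", reusing the
// GeMask construction with a shifted lane index (LeMask below reuses LtMask the same way).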
2523 statement("uint gt_lane_index = WaveGetLaneIndex() + 1;");
2524 statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
2525 statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
2526 statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
2527 statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
2528 statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
2529 statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
2530 statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
2531 statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
2532 break;
2533
2534 case BuiltInSubgroupLeMask:
2535 // Emulate these ...
2536 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2537 statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
2538 statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
2539 statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
2540 statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
2541 statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
2542 statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
2543 statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
2544 statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
2545 statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
2546 break;
2547
2548 case BuiltInSubgroupLtMask:
2549 // Emulate these ...
2550 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2551 statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
2552 statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
2553 statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
2554 statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
2555 statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
2556 statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
2557 statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
2558 break;
2559
2560 case BuiltInClipDistance:
2561 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2562 statement("gl_ClipDistance[", clip, "] = stage_input.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3],
2563 ";");
2564 break;
2565
2566 case BuiltInCullDistance:
2567 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2568 statement("gl_CullDistance[", cull, "] = stage_input.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3],
2569 ";");
2570 break;
2571
2572 default:
2573 statement(builtin, " = stage_input.", builtin, ";");
2574 break;
2575 }
2576 });
2577
2578 // Copy from stage input struct to globals.
2579 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2580 auto &type = this->get<SPIRType>(var.basetype);
2581 bool block = has_decoration(type.self, DecorationBlock);
2582
2583 if (var.storage != StorageClassInput)
2584 return;
2585
2586 bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
2587
2588 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
2589 interface_variable_exists_in_entry_point(var.self))
2590 {
2591 if (block)
2592 {
2593 auto type_name = to_name(type.self);
2594 auto var_name = to_name(var.self);
2595 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2596 {
2597 auto mbr_name = to_member_name(type, mbr_idx);
2598 auto flat_name = join(type_name, "_", mbr_name);
2599 statement(var_name, ".", mbr_name, " = stage_input.", flat_name, ";");
2600 }
2601 }
2602 else
2603 {
2604 auto name = to_name(var.self);
2605 auto &mtype = this->get<SPIRType>(var.basetype);
2606 if (need_matrix_unroll && mtype.columns > 1)
2607 {
2608 // Unroll matrices.
2609 for (uint32_t col = 0; col < mtype.columns; col++)
2610 statement(name, "[", col, "] = stage_input.", name, "_", col, ";");
2611 }
2612 else
2613 {
2614 statement(name, " = stage_input.", name, ";");
2615 }
2616 }
2617 }
2618 });
2619
2620 // Run the shader.
2621 if (execution.model == ExecutionModelVertex)
2622 statement("vert_main();");
2623 else if (execution.model == ExecutionModelFragment)
2624 statement("frag_main();");
2625 else if (execution.model == ExecutionModelGLCompute)
2626 statement("comp_main();");
2627 else
2628 SPIRV_CROSS_THROW("Unsupported shader stage.");
2629
2630 // Copy stage outputs.
2631 if (require_output)
2632 {
2633 statement("SPIRV_Cross_Output stage_output;");
2634
2635 // Copy builtins from globals to return struct.
2636 active_output_builtins.for_each_bit([&](uint32_t i) {
2637 // PointSize doesn't exist in HLSL.
2638 if (i == BuiltInPointSize)
2639 return;
2640
2641 switch (static_cast<BuiltIn>(i))
2642 {
2643 case BuiltInClipDistance:
2644 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2645 statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[",
2646 clip, "];");
2647 break;
2648
2649 case BuiltInCullDistance:
2650 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2651 statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[",
2652 cull, "];");
2653 break;
2654
2655 default:
2656 {
2657 auto builtin_expr = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassOutput);
2658 statement("stage_output.", builtin_expr, " = ", builtin_expr, ";");
2659 break;
2660 }
2661 }
2662 });
2663
2664 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2665 auto &type = this->get<SPIRType>(var.basetype);
2666 bool block = has_decoration(type.self, DecorationBlock);
2667
2668 if (var.storage != StorageClassOutput)
2669 return;
2670
2671 if (!var.remapped_variable && type.pointer &&
2672 !is_builtin_variable(var) &&
2673 interface_variable_exists_in_entry_point(var.self))
2674 {
2675 if (block)
2676 {
2677 // I/O blocks need to flatten output.
2678 auto type_name = to_name(type.self);
2679 auto var_name = to_name(var.self);
2680 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2681 {
2682 auto mbr_name = to_member_name(type, mbr_idx);
2683 auto flat_name = join(type_name, "_", mbr_name);
2684 statement("stage_output.", flat_name, " = ", var_name, ".", mbr_name, ";");
2685 }
2686 }
2687 else
2688 {
2689 auto name = to_name(var.self);
2690
2691 if (legacy && execution.model == ExecutionModelFragment)
2692 {
2693 string output_filler;
2694 for (uint32_t size = type.vecsize; size < 4; ++size)
2695 output_filler += ", 0.0";
2696
2697 statement("stage_output.", name, " = float4(", name, output_filler, ");");
2698 }
2699 else
2700 {
2701 statement("stage_output.", name, " = ", name, ";");
2702 }
2703 }
2704 }
2705 });
2706
2707 statement("return stage_output;");
2708 }
2709
2710 end_scope();
2711 }
2712
2713 void CompilerHLSL::emit_fixup()
2714 {
2715 if (is_vertex_like_shader())
2716 {
2717 // Do various mangling on the gl_Position.
2718 if (hlsl_options.shader_model <= 30)
2719 {
2720 statement("gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
2721 "gl_Position.w;");
2722 statement("gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
2723 "gl_Position.w;");
2724 }
2725
2726 if (options.vertex.flip_vert_y)
2727 statement("gl_Position.y = -gl_Position.y;");
2728 if (options.vertex.fixup_clipspace)
2729 statement("gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
2730 }
2731 }
2732
2733 void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
2734 {
2735 if (sparse)
2736 SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
2737
2738 auto *ops = stream(i);
2739 auto op = static_cast<Op>(i.op);
2740 uint32_t length = i.length;
2741
2742 SmallVector<uint32_t> inherited_expressions;
2743
2744 uint32_t result_type = ops[0];
2745 uint32_t id = ops[1];
2746 VariableID img = ops[2];
2747 uint32_t coord = ops[3];
2748 uint32_t dref = 0;
2749 uint32_t comp = 0;
2750 bool gather = false;
2751 bool proj = false;
2752 const uint32_t *opt = nullptr;
2753 auto *combined_image = maybe_get<SPIRCombinedImageSampler>(img);
2754
2755 if (combined_image && has_decoration(img, DecorationNonUniform))
2756 {
2757 set_decoration(combined_image->image, DecorationNonUniform);
2758 set_decoration(combined_image->sampler, DecorationNonUniform);
2759 }
2760
2761 auto img_expr = to_non_uniform_aware_expression(combined_image ? combined_image->image : img);
2762
2763 inherited_expressions.push_back(coord);
2764
2765 switch (op)
2766 {
2767 case OpImageSampleDrefImplicitLod:
2768 case OpImageSampleDrefExplicitLod:
2769 dref = ops[4];
2770 opt = &ops[5];
2771 length -= 5;
2772 break;
2773
2774 case OpImageSampleProjDrefImplicitLod:
2775 case OpImageSampleProjDrefExplicitLod:
2776 dref = ops[4];
2777 proj = true;
2778 opt = &ops[5];
2779 length -= 5;
2780 break;
2781
2782 case OpImageDrefGather:
2783 dref = ops[4];
2784 opt = &ops[5];
2785 gather = true;
2786 length -= 5;
2787 break;
2788
2789 case OpImageGather:
2790 comp = ops[4];
2791 opt = &ops[5];
2792 gather = true;
2793 length -= 5;
2794 break;
2795
2796 case OpImageSampleProjImplicitLod:
2797 case OpImageSampleProjExplicitLod:
2798 opt = &ops[4];
2799 length -= 4;
2800 proj = true;
2801 break;
2802
2803 case OpImageQueryLod:
2804 opt = &ops[4];
2805 length -= 4;
2806 break;
2807
2808 default:
2809 opt = &ops[4];
2810 length -= 4;
2811 break;
2812 }
2813
2814 auto &imgtype = expression_type(img);
2815 uint32_t coord_components = 0;
2816 switch (imgtype.image.dim)
2817 {
2818 case spv::Dim1D:
2819 coord_components = 1;
2820 break;
2821 case spv::Dim2D:
2822 coord_components = 2;
2823 break;
2824 case spv::Dim3D:
2825 coord_components = 3;
2826 break;
2827 case spv::DimCube:
2828 coord_components = 3;
2829 break;
2830 case spv::DimBuffer:
2831 coord_components = 1;
2832 break;
2833 default:
2834 coord_components = 2;
2835 break;
2836 }
2837
2838 if (dref)
2839 inherited_expressions.push_back(dref);
2840
2841 if (imgtype.image.arrayed)
2842 coord_components++;
2843
2844 uint32_t bias = 0;
2845 uint32_t lod = 0;
2846 uint32_t grad_x = 0;
2847 uint32_t grad_y = 0;
2848 uint32_t coffset = 0;
2849 uint32_t offset = 0;
2850 uint32_t coffsets = 0;
2851 uint32_t sample = 0;
2852 uint32_t minlod = 0;
2853 uint32_t flags = 0;
2854
2855 if (length)
2856 {
2857 flags = opt[0];
2858 opt++;
2859 length--;
2860 }
2861
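// Optional image operands are encoded in ascending order of their flag bits, so test each
// mask in that order and consume one operand word per set flag (Grad consumes two).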
2862 auto test = [&](uint32_t &v, uint32_t flag) {
2863 if (length && (flags & flag))
2864 {
2865 v = *opt++;
2866 inherited_expressions.push_back(v);
2867 length--;
2868 }
2869 };
2870
2871 test(bias, ImageOperandsBiasMask);
2872 test(lod, ImageOperandsLodMask);
2873 test(grad_x, ImageOperandsGradMask);
2874 test(grad_y, ImageOperandsGradMask);
2875 test(coffset, ImageOperandsConstOffsetMask);
2876 test(offset, ImageOperandsOffsetMask);
2877 test(coffsets, ImageOperandsConstOffsetsMask);
2878 test(sample, ImageOperandsSampleMask);
2879 test(minlod, ImageOperandsMinLodMask);
2880
2881 string expr;
2882 string texop;
2883
2884 if (minlod != 0)
2885 SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
2886
2887 if (op == OpImageFetch)
2888 {
2889 if (hlsl_options.shader_model < 40)
2890 {
2891 SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
2892 }
2893 texop += img_expr;
2894 texop += ".Load";
2895 }
2896 else if (op == OpImageQueryLod)
2897 {
2898 texop += img_expr;
2899 texop += ".CalculateLevelOfDetail";
2900 }
2901 else
2902 {
2903 auto &imgformat = get<SPIRType>(imgtype.image.type);
2904 if (imgformat.basetype != SPIRType::Float)
2905 {
2906 SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL.");
2907 }
2908
2909 if (hlsl_options.shader_model >= 40)
2910 {
2911 texop += img_expr;
2912
2913 if (image_is_comparison(imgtype, img))
2914 {
2915 if (gather)
2916 {
2917 SPIRV_CROSS_THROW("GatherCmp does not exist in HLSL.");
2918 }
2919 else if (lod || grad_x || grad_y)
2920 {
2921 // Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
2922 texop += ".SampleCmpLevelZero";
2923 }
2924 else
2925 texop += ".SampleCmp";
2926 }
2927 else if (gather)
2928 {
2929 uint32_t comp_num = evaluate_constant_u32(comp);
2930 if (hlsl_options.shader_model >= 50)
2931 {
2932 switch (comp_num)
2933 {
2934 case 0:
2935 texop += ".GatherRed";
2936 break;
2937 case 1:
2938 texop += ".GatherGreen";
2939 break;
2940 case 2:
2941 texop += ".GatherBlue";
2942 break;
2943 case 3:
2944 texop += ".GatherAlpha";
2945 break;
2946 default:
2947 SPIRV_CROSS_THROW("Invalid component.");
2948 }
2949 }
2950 else
2951 {
2952 if (comp_num == 0)
2953 texop += ".Gather";
2954 else
2955 SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
2956 }
2957 }
2958 else if (bias)
2959 texop += ".SampleBias";
2960 else if (grad_x || grad_y)
2961 texop += ".SampleGrad";
2962 else if (lod)
2963 texop += ".SampleLevel";
2964 else
2965 texop += ".Sample";
2966 }
2967 else
2968 {
2969 switch (imgtype.image.dim)
2970 {
2971 case Dim1D:
2972 texop += "tex1D";
2973 break;
2974 case Dim2D:
2975 texop += "tex2D";
2976 break;
2977 case Dim3D:
2978 texop += "tex3D";
2979 break;
2980 case DimCube:
2981 texop += "texCUBE";
2982 break;
2983 case DimRect:
2984 case DimBuffer:
2985 case DimSubpassData:
2986 SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL"); // TODO
2987 default:
2988 SPIRV_CROSS_THROW("Invalid dimension.");
2989 }
2990
2991 if (gather)
2992 SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
2993 if (offset || coffset)
2994 SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
2995
2996 if (grad_x || grad_y)
2997 texop += "grad";
2998 else if (lod)
2999 texop += "lod";
3000 else if (bias)
3001 texop += "bias";
3002 else if (proj || dref)
3003 texop += "proj";
3004 }
3005 }
3006
3007 expr += texop;
3008 expr += "(";
3009 if (hlsl_options.shader_model < 40)
3010 {
3011 if (combined_image)
3012 SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3013 expr += to_expression(img);
3014 }
3015 else if (op != OpImageFetch)
3016 {
3017 string sampler_expr;
3018 if (combined_image)
3019 sampler_expr = to_non_uniform_aware_expression(combined_image->sampler);
3020 else
3021 sampler_expr = to_sampler_expression(img);
3022 expr += sampler_expr;
3023 }
3024
3025 auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3026 if (comps == in_comps)
3027 return "";
3028
3029 switch (comps)
3030 {
3031 case 1:
3032 return ".x";
3033 case 2:
3034 return ".xy";
3035 case 3:
3036 return ".xyz";
3037 default:
3038 return "";
3039 }
3040 };
3041
3042 bool forward = should_forward(coord);
3043
3044 // The IR can give us more components than we need, so chop them off as needed.
3045 string coord_expr;
3046 auto &coord_type = expression_type(coord);
3047 if (coord_components != coord_type.vecsize)
3048 coord_expr = to_enclosed_expression(coord) + swizzle(coord_components, expression_type(coord).vecsize);
3049 else
3050 coord_expr = to_expression(coord);
3051
3052 if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3053 coord_expr = coord_expr + " / " + to_extract_component_expression(coord, coord_components);
3054
3055 if (hlsl_options.shader_model < 40)
3056 {
3057 if (dref)
3058 {
3059 if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3060 {
3061 SPIRV_CROSS_THROW(
3062 "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3063 }
3064
3065 if (grad_x || grad_y)
3066 SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3067
3068 for (uint32_t size = coord_components; size < 2; ++size)
3069 coord_expr += ", 0.0";
3070
3071 forward = forward && should_forward(dref);
3072 coord_expr += ", " + to_expression(dref);
3073 }
3074 else if (lod || bias || proj)
3075 {
3076 for (uint32_t size = coord_components; size < 3; ++size)
3077 coord_expr += ", 0.0";
3078 }
3079
3080 if (lod)
3081 {
3082 coord_expr = "float4(" + coord_expr + ", " + to_expression(lod) + ")";
3083 }
3084 else if (bias)
3085 {
3086 coord_expr = "float4(" + coord_expr + ", " + to_expression(bias) + ")";
3087 }
3088 else if (proj)
3089 {
3090 coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(coord, coord_components) + ")";
3091 }
3092 else if (dref)
3093 {
3094 // A "normal" sample gets fed into tex2Dproj as well, because the
3095 // regular tex2D accepts only two coordinates.
3096 coord_expr = "float4(" + coord_expr + ", 1.0)";
3097 }
3098
3099 if (!!lod + !!bias + !!proj > 1)
3100 SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3101 }
3102
3103 if (op == OpImageFetch)
3104 {
3105 if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
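// Texture.Load() takes an int vector whose last component is the mip level; buffer and
// multisampled resources address texels directly, so no LOD component is appended there.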
3106 coord_expr =
3107 join("int", coord_components + 1, "(", coord_expr, ", ", lod ? to_expression(lod) : string("0"), ")");
3108 }
3109 else
3110 expr += ", ";
3111 expr += coord_expr;
3112
3113 if (dref && hlsl_options.shader_model >= 40)
3114 {
3115 forward = forward && should_forward(dref);
3116 expr += ", ";
3117
3118 if (proj)
3119 expr += to_enclosed_expression(dref) + " / " + to_extract_component_expression(coord, coord_components);
3120 else
3121 expr += to_expression(dref);
3122 }
3123
3124 if (!dref && (grad_x || grad_y))
3125 {
3126 forward = forward && should_forward(grad_x);
3127 forward = forward && should_forward(grad_y);
3128 expr += ", ";
3129 expr += to_expression(grad_x);
3130 expr += ", ";
3131 expr += to_expression(grad_y);
3132 }
3133
3134 if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3135 {
3136 forward = forward && should_forward(lod);
3137 expr += ", ";
3138 expr += to_expression(lod);
3139 }
3140
3141 if (!dref && bias && hlsl_options.shader_model >= 40)
3142 {
3143 forward = forward && should_forward(bias);
3144 expr += ", ";
3145 expr += to_expression(bias);
3146 }
3147
3148 if (coffset)
3149 {
3150 forward = forward && should_forward(coffset);
3151 expr += ", ";
3152 expr += to_expression(coffset);
3153 }
3154 else if (offset)
3155 {
3156 forward = forward && should_forward(offset);
3157 expr += ", ";
3158 expr += to_expression(offset);
3159 }
3160
3161 if (sample)
3162 {
3163 expr += ", ";
3164 expr += to_expression(sample);
3165 }
3166
3167 expr += ")";
3168
3169 if (dref && hlsl_options.shader_model < 40)
3170 expr += ".x";
3171
3172 if (op == OpImageQueryLod)
3173 {
3174 // This is rather awkward.
3175 // textureQueryLod returns two values: the "accessed level",
3176 // as well as the actual LOD lambda.
3177 // As far as I can tell, the GLSL spec gives no way to compute the .x component,
3178 // since it depends on the sampler state itself.
3179 // Just assume X == Y, so we splat the single HLSL result to a float2.
3180 statement("float _", id, "_tmp = ", expr, ";");
3181 statement("float2 _", id, " = _", id, "_tmp.xx;");
3182 set<SPIRExpression>(id, join("_", id), result_type, true);
3183 }
3184 else
3185 {
3186 emit_op(result_type, id, expr, forward, false);
3187 }
3188
3189 for (auto &inherit : inherited_expressions)
3190 inherit_expression_dependencies(id, inherit);
3191
3192 switch (op)
3193 {
3194 case OpImageSampleDrefImplicitLod:
3195 case OpImageSampleImplicitLod:
3196 case OpImageSampleProjImplicitLod:
3197 case OpImageSampleProjDrefImplicitLod:
3198 register_control_dependent_expression(id);
3199 break;
3200
3201 default:
3202 break;
3203 }
3204 }
3205
3206 string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3207 {
3208 const auto &type = get<SPIRType>(var.basetype);
3209
3210 // We can remap push constant blocks, even if they don't have any binding decoration.
3211 if (type.storage != StorageClassPushConstant && !has_decoration(var.self, DecorationBinding))
3212 return "";
3213
3214 char space = '\0';
3215
3216 HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3217
3218 switch (type.basetype)
3219 {
3220 case SPIRType::SampledImage:
3221 space = 't'; // SRV
3222 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3223 break;
3224
3225 case SPIRType::Image:
3226 if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3227 {
3228 if (has_decoration(var.self, DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3229 {
3230 space = 't'; // SRV
3231 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3232 }
3233 else
3234 {
3235 space = 'u'; // UAV
3236 resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3237 }
3238 }
3239 else
3240 {
3241 space = 't'; // SRV
3242 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3243 }
3244 break;
3245
3246 case SPIRType::Sampler:
3247 space = 's';
3248 resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
3249 break;
3250
3251 case SPIRType::Struct:
3252 {
3253 auto storage = type.storage;
3254 if (storage == StorageClassUniform)
3255 {
3256 if (has_decoration(type.self, DecorationBufferBlock))
3257 {
3258 Bitset flags = ir.get_buffer_block_flags(var);
3259 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3260 space = is_readonly ? 't' : 'u'; // SRV : UAV
3261 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3262 }
3263 else if (has_decoration(type.self, DecorationBlock))
3264 {
3265 space = 'b'; // Constant buffers
3266 resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
3267 }
3268 }
3269 else if (storage == StorageClassPushConstant)
3270 {
3271 space = 'b'; // Constant buffers
3272 resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
3273 }
3274 else if (storage == StorageClassStorageBuffer)
3275 {
3276 // UAV or SRV depending on readonly flag.
3277 Bitset flags = ir.get_buffer_block_flags(var);
3278 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3279 space = is_readonly ? 't' : 'u';
3280 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3281 }
3282
3283 break;
3284 }
3285 default:
3286 break;
3287 }
3288
3289 if (!space)
3290 return "";
3291
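// Default to the magic push-constant set/binding so that a user-provided remap for the
// push constant block can still be located even though the block itself carries no
// DecorationBinding/DecorationDescriptorSet.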
3292 uint32_t desc_set =
3293 resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
3294 uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
3295
3296 if (has_decoration(var.self, DecorationBinding))
3297 binding = get_decoration(var.self, DecorationBinding);
3298 if (has_decoration(var.self, DecorationDescriptorSet))
3299 desc_set = get_decoration(var.self, DecorationDescriptorSet);
3300
3301 return to_resource_register(resource_flags, space, binding, desc_set);
3302 }
3303
3304 string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
3305 {
3306 // For combined image samplers.
3307 if (!has_decoration(var.self, DecorationBinding))
3308 return "";
3309
3310 return to_resource_register(HLSL_BINDING_AUTO_SAMPLER_BIT, 's', get_decoration(var.self, DecorationBinding),
3311 get_decoration(var.self, DecorationDescriptorSet));
3312 }
3313
3314 void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
3315 {
3316 auto itr = resource_bindings.find({ get_execution_model(), desc_set, binding });
3317 if (itr != end(resource_bindings))
3318 {
3319 auto &remap = itr->second;
3320 remap.second = true;
3321
3322 switch (type)
3323 {
3324 case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
3325 case HLSL_BINDING_AUTO_CBV_BIT:
3326 desc_set = remap.first.cbv.register_space;
3327 binding = remap.first.cbv.register_binding;
3328 break;
3329
3330 case HLSL_BINDING_AUTO_SRV_BIT:
3331 desc_set = remap.first.srv.register_space;
3332 binding = remap.first.srv.register_binding;
3333 break;
3334
3335 case HLSL_BINDING_AUTO_SAMPLER_BIT:
3336 desc_set = remap.first.sampler.register_space;
3337 binding = remap.first.sampler.register_binding;
3338 break;
3339
3340 case HLSL_BINDING_AUTO_UAV_BIT:
3341 desc_set = remap.first.uav.register_space;
3342 binding = remap.first.uav.register_binding;
3343 break;
3344
3345 default:
3346 break;
3347 }
3348 }
3349 }
3350
3351 string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
3352 {
3353 if ((flag & resource_binding_flags) == 0)
3354 {
3355 remap_hlsl_resource_binding(flag, space_set, binding);
3356
3357 // The push constant block did not have a binding, and there was no remap for it,
3358 // so declare it without a register binding.
3359 if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
3360 return "";
3361
3362 if (hlsl_options.shader_model >= 51)
3363 return join(" : register(", space, binding, ", space", space_set, ")");
3364 else
3365 return join(" : register(", space, binding, ")");
3366 }
3367 else
3368 return "";
3369 }
3370
3371 void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
3372 {
3373 auto &type = get<SPIRType>(var.basetype);
3374 switch (type.basetype)
3375 {
3376 case SPIRType::SampledImage:
3377 case SPIRType::Image:
3378 {
3379 bool is_coherent = false;
3380 if (type.basetype == SPIRType::Image && type.image.sampled == 2)
3381 is_coherent = has_decoration(var.self, DecorationCoherent);
3382
3383 statement(is_coherent ? "globallycoherent " : "", image_type_hlsl_modern(type, var.self), " ",
3384 to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3385
3386 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
3387 {
3388 // For combined image samplers, also emit a combined image sampler.
3389 if (image_is_comparison(type, var.self))
3390 statement("SamplerComparisonState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3391 to_resource_binding_sampler(var), ";");
3392 else
3393 statement("SamplerState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3394 to_resource_binding_sampler(var), ";");
3395 }
3396 break;
3397 }
3398
3399 case SPIRType::Sampler:
3400 if (comparison_ids.count(var.self))
3401 statement("SamplerComparisonState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var),
3402 ";");
3403 else
3404 statement("SamplerState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3405 break;
3406
3407 default:
3408 statement(variable_decl(var), to_resource_binding(var), ";");
3409 break;
3410 }
3411 }
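// As an illustration (names are representative, not verbatim output), a
// combined image sampler at binding 1 typically becomes two declarations:
//
//   Texture2D<float4> uTex : register(t1);
//   SamplerState _uTex_sampler : register(s1);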
3412
3413 void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
3414 {
3415 auto &type = get<SPIRType>(var.basetype);
3416 switch (type.basetype)
3417 {
3418 case SPIRType::Sampler:
3419 case SPIRType::Image:
3420 SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
3421
3422 default:
3423 statement(variable_decl(var), ";");
3424 break;
3425 }
3426 }
3427
3428 void CompilerHLSL::emit_uniform(const SPIRVariable &var)
3429 {
3430 add_resource_name(var.self);
3431 if (hlsl_options.shader_model >= 40)
3432 emit_modern_uniform(var);
3433 else
3434 emit_legacy_uniform(var);
3435 }
3436
3437 bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
3438 {
3439 return false;
3440 }
3441
3442 string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
3443 {
3444 if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
3445 return type_to_glsl(out_type);
3446 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
3447 return type_to_glsl(out_type);
3448 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
3449 return "asuint";
3450 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
3451 return type_to_glsl(out_type);
3452 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
3453 return type_to_glsl(out_type);
3454 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
3455 return "asint";
3456 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
3457 return "asfloat";
3458 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
3459 return "asfloat";
3460 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
3461 SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
3462 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
3463 SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
3464 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
3465 return "asdouble";
3466 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
3467 return "asdouble";
3468 else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
3469 {
3470 if (!requires_explicit_fp16_packing)
3471 {
3472 requires_explicit_fp16_packing = true;
3473 force_recompile();
3474 }
3475 return "spvUnpackFloat2x16";
3476 }
3477 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
3478 {
3479 if (!requires_explicit_fp16_packing)
3480 {
3481 requires_explicit_fp16_packing = true;
3482 force_recompile();
3483 }
3484 return "spvPackFloat2x16";
3485 }
3486 else
3487 return "";
3488 }
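// In short: same-width integer casts degrade to plain constructor syntax,
// while float<->int reinterpretation maps to HLSL intrinsics, e.g.
//
//   asuint(1.0f)          // float bits viewed as uint
//   asfloat(0x3f800000u)  // uint bits viewed back as float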
3489
3490 void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
3491 {
3492 auto op = static_cast<GLSLstd450>(eop);
3493
3494 // If we need to do implicit bitcasts, make sure we do it with the correct type.
3495 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
3496 auto int_type = to_signed_basetype(integer_width);
3497 auto uint_type = to_unsigned_basetype(integer_width);
3498
3499 switch (op)
3500 {
3501 case GLSLstd450InverseSqrt:
3502 emit_unary_func_op(result_type, id, args[0], "rsqrt");
3503 break;
3504
3505 case GLSLstd450Fract:
3506 emit_unary_func_op(result_type, id, args[0], "frac");
3507 break;
3508
3509 case GLSLstd450RoundEven:
3510 if (hlsl_options.shader_model < 40)
3511 SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
3512 emit_unary_func_op(result_type, id, args[0], "round");
3513 break;
3514
3515 case GLSLstd450Acosh:
3516 case GLSLstd450Asinh:
3517 case GLSLstd450Atanh:
3518 SPIRV_CROSS_THROW("Inverse hyperbolics are not supported on HLSL.");
3519
3520 case GLSLstd450FMix:
3521 case GLSLstd450IMix:
3522 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "lerp");
3523 break;
3524
3525 case GLSLstd450Atan2:
3526 emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
3527 break;
3528
3529 case GLSLstd450Fma:
3530 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "mad");
3531 break;
3532
3533 case GLSLstd450InterpolateAtCentroid:
3534 emit_unary_func_op(result_type, id, args[0], "EvaluateAttributeAtCentroid");
3535 break;
3536 case GLSLstd450InterpolateAtSample:
3537 emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeAtSample");
3538 break;
3539 case GLSLstd450InterpolateAtOffset:
3540 emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeSnapped");
3541 break;
3542
3543 case GLSLstd450PackHalf2x16:
3544 if (!requires_fp16_packing)
3545 {
3546 requires_fp16_packing = true;
3547 force_recompile();
3548 }
3549 emit_unary_func_op(result_type, id, args[0], "spvPackHalf2x16");
3550 break;
3551
3552 case GLSLstd450UnpackHalf2x16:
3553 if (!requires_fp16_packing)
3554 {
3555 requires_fp16_packing = true;
3556 force_recompile();
3557 }
3558 emit_unary_func_op(result_type, id, args[0], "spvUnpackHalf2x16");
3559 break;
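// The spvPackHalf2x16/spvUnpackHalf2x16 helpers requested here are emitted
// later during recompilation; a sketch of the pack direction, built on the
// f32tof16 intrinsic:
//
//   uint spvPackHalf2x16(float2 value)
//   {
//       uint2 packed = f32tof16(value);
//       return packed.x | (packed.y << 16);
//   }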
3560
3561 case GLSLstd450PackSnorm4x8:
3562 if (!requires_snorm8_packing)
3563 {
3564 requires_snorm8_packing = true;
3565 force_recompile();
3566 }
3567 emit_unary_func_op(result_type, id, args[0], "spvPackSnorm4x8");
3568 break;
3569
3570 case GLSLstd450UnpackSnorm4x8:
3571 if (!requires_snorm8_packing)
3572 {
3573 requires_snorm8_packing = true;
3574 force_recompile();
3575 }
3576 emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm4x8");
3577 break;
3578
3579 case GLSLstd450PackUnorm4x8:
3580 if (!requires_unorm8_packing)
3581 {
3582 requires_unorm8_packing = true;
3583 force_recompile();
3584 }
3585 emit_unary_func_op(result_type, id, args[0], "spvPackUnorm4x8");
3586 break;
3587
3588 case GLSLstd450UnpackUnorm4x8:
3589 if (!requires_unorm8_packing)
3590 {
3591 requires_unorm8_packing = true;
3592 force_recompile();
3593 }
3594 emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm4x8");
3595 break;
3596
3597 case GLSLstd450PackSnorm2x16:
3598 if (!requires_snorm16_packing)
3599 {
3600 requires_snorm16_packing = true;
3601 force_recompile();
3602 }
3603 emit_unary_func_op(result_type, id, args[0], "spvPackSnorm2x16");
3604 break;
3605
3606 case GLSLstd450UnpackSnorm2x16:
3607 if (!requires_snorm16_packing)
3608 {
3609 requires_snorm16_packing = true;
3610 force_recompile();
3611 }
3612 emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm2x16");
3613 break;
3614
3615 case GLSLstd450PackUnorm2x16:
3616 if (!requires_unorm16_packing)
3617 {
3618 requires_unorm16_packing = true;
3619 force_recompile();
3620 }
3621 emit_unary_func_op(result_type, id, args[0], "spvPackUnorm2x16");
3622 break;
3623
3624 case GLSLstd450UnpackUnorm2x16:
3625 if (!requires_unorm16_packing)
3626 {
3627 requires_unorm16_packing = true;
3628 force_recompile();
3629 }
3630 emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm2x16");
3631 break;
3632
3633 case GLSLstd450PackDouble2x32:
3634 case GLSLstd450UnpackDouble2x32:
3635 SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
3636
3637 case GLSLstd450FindILsb:
3638 {
3639 auto basetype = expression_type(args[0]).basetype;
3640 emit_unary_func_op_cast(result_type, id, args[0], "firstbitlow", basetype, basetype);
3641 break;
3642 }
3643
3644 case GLSLstd450FindSMsb:
3645 emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", int_type, int_type);
3646 break;
3647
3648 case GLSLstd450FindUMsb:
3649 emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", uint_type, uint_type);
3650 break;
3651
3652 case GLSLstd450MatrixInverse:
3653 {
3654 auto &type = get<SPIRType>(result_type);
3655 if (type.vecsize == 2 && type.columns == 2)
3656 {
3657 if (!requires_inverse_2x2)
3658 {
3659 requires_inverse_2x2 = true;
3660 force_recompile();
3661 }
3662 }
3663 else if (type.vecsize == 3 && type.columns == 3)
3664 {
3665 if (!requires_inverse_3x3)
3666 {
3667 requires_inverse_3x3 = true;
3668 force_recompile();
3669 }
3670 }
3671 else if (type.vecsize == 4 && type.columns == 4)
3672 {
3673 if (!requires_inverse_4x4)
3674 {
3675 requires_inverse_4x4 = true;
3676 force_recompile();
3677 }
3678 }
3679 emit_unary_func_op(result_type, id, args[0], "spvInverse");
3680 break;
3681 }
3682
3683 case GLSLstd450Normalize:
3684 // HLSL does not support scalar versions here.
3685 if (expression_type(args[0]).vecsize == 1)
3686 {
3687 // Normalizing a scalar returns -1 or 1 for valid input; sign() does the job.
3688 emit_unary_func_op(result_type, id, args[0], "sign");
3689 }
3690 else
3691 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3692 break;
3693
3694 case GLSLstd450Reflect:
3695 if (get<SPIRType>(result_type).vecsize == 1)
3696 {
3697 if (!requires_scalar_reflect)
3698 {
3699 requires_scalar_reflect = true;
3700 force_recompile();
3701 }
3702 emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
3703 }
3704 else
3705 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3706 break;
3707
3708 case GLSLstd450Refract:
3709 if (get<SPIRType>(result_type).vecsize == 1)
3710 {
3711 if (!requires_scalar_refract)
3712 {
3713 requires_scalar_refract = true;
3714 force_recompile();
3715 }
3716 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
3717 }
3718 else
3719 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3720 break;
3721
3722 case GLSLstd450FaceForward:
3723 if (get<SPIRType>(result_type).vecsize == 1)
3724 {
3725 if (!requires_scalar_faceforward)
3726 {
3727 requires_scalar_faceforward = true;
3728 force_recompile();
3729 }
3730 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
3731 }
3732 else
3733 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3734 break;
3735
3736 default:
3737 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3738 break;
3739 }
3740 }
3741
3742 void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
3743 {
3744 auto &type = get<SPIRType>(chain.basetype);
3745
3746 // Use a reserved identifier here, since an ordinary name might shadow an identifier in the access chain input or other loops.
3747 auto ident = get_unique_identifier();
3748
3749 statement("[unroll]");
3750 statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
3751 ident, "++)");
3752 begin_scope();
3753 auto subchain = chain;
3754 subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
3755 subchain.basetype = type.parent_type;
3756 if (!get<SPIRType>(subchain.basetype).array.empty())
3757 subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
3758 read_access_chain(nullptr, join(lhs, "[", ident, "]"), subchain);
3759 end_scope();
3760 }
3761
3762 void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
3763 {
3764 auto &type = get<SPIRType>(chain.basetype);
3765 auto subchain = chain;
3766 uint32_t member_count = uint32_t(type.member_types.size());
3767
3768 for (uint32_t i = 0; i < member_count; i++)
3769 {
3770 uint32_t offset = type_struct_member_offset(type, i);
3771 subchain.static_index = chain.static_index + offset;
3772 subchain.basetype = type.member_types[i];
3773
3774 subchain.matrix_stride = 0;
3775 subchain.array_stride = 0;
3776 subchain.row_major_matrix = false;
3777
3778 auto &member_type = get<SPIRType>(subchain.basetype);
3779 if (member_type.columns > 1)
3780 {
3781 subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
3782 subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
3783 }
3784
3785 if (!member_type.array.empty())
3786 subchain.array_stride = type_struct_member_array_stride(type, i);
3787
3788 read_access_chain(nullptr, join(lhs, ".", to_member_name(type, i)), subchain);
3789 }
3790 }
3791
3792 void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
3793 {
3794 auto &type = get<SPIRType>(chain.basetype);
3795
3796 SPIRType target_type;
3797 target_type.basetype = SPIRType::UInt;
3798 target_type.vecsize = type.vecsize;
3799 target_type.columns = type.columns;
3800
3801 if (!type.array.empty())
3802 {
3803 read_access_chain_array(lhs, chain);
3804 return;
3805 }
3806 else if (type.basetype == SPIRType::Struct)
3807 {
3808 read_access_chain_struct(lhs, chain);
3809 return;
3810 }
3811 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
3812 SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
3813 "native 16-bit types are enabled.");
3814
3815 string base = chain.base;
3816 if (has_decoration(chain.self, DecorationNonUniform))
3817 convert_non_uniform_expression(base, chain.self);
3818
3819 bool templated_load = hlsl_options.shader_model >= 62;
3820 string load_expr;
3821
3822 string template_expr;
3823 if (templated_load)
3824 template_expr = join("<", type_to_glsl(type), ">");
3825
3826 // Load a vector or scalar.
3827 if (type.columns == 1 && !chain.row_major_matrix)
3828 {
3829 const char *load_op = nullptr;
3830 switch (type.vecsize)
3831 {
3832 case 1:
3833 load_op = "Load";
3834 break;
3835 case 2:
3836 load_op = "Load2";
3837 break;
3838 case 3:
3839 load_op = "Load3";
3840 break;
3841 case 4:
3842 load_op = "Load4";
3843 break;
3844 default:
3845 SPIRV_CROSS_THROW("Unknown vector size.");
3846 }
3847
3848 if (templated_load)
3849 load_op = "Load";
3850
3851 load_expr = join(base, ".", load_op, template_expr, "(", chain.dynamic_index, chain.static_index, ")");
3852 }
3853 else if (type.columns == 1)
3854 {
3855 // Strided load since we are loading a column from a row-major matrix.
3856 if (templated_load)
3857 {
3858 auto scalar_type = type;
3859 scalar_type.vecsize = 1;
3860 scalar_type.columns = 1;
3861 template_expr = join("<", type_to_glsl(scalar_type), ">");
3862 if (type.vecsize > 1)
3863 load_expr += type_to_glsl(type) + "(";
3864 }
3865 else if (type.vecsize > 1)
3866 {
3867 load_expr = type_to_glsl(target_type);
3868 load_expr += "(";
3869 }
3870
3871 for (uint32_t r = 0; r < type.vecsize; r++)
3872 {
3873 load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3874 chain.static_index + r * chain.matrix_stride, ")");
3875 if (r + 1 < type.vecsize)
3876 load_expr += ", ";
3877 }
3878
3879 if (type.vecsize > 1)
3880 load_expr += ")";
3881 }
3882 else if (!chain.row_major_matrix)
3883 {
3884 // Load a matrix, column-major, the easy case.
3885 const char *load_op = nullptr;
3886 switch (type.vecsize)
3887 {
3888 case 1:
3889 load_op = "Load";
3890 break;
3891 case 2:
3892 load_op = "Load2";
3893 break;
3894 case 3:
3895 load_op = "Load3";
3896 break;
3897 case 4:
3898 load_op = "Load4";
3899 break;
3900 default:
3901 SPIRV_CROSS_THROW("Unknown vector size.");
3902 }
3903
3904 if (templated_load)
3905 {
3906 auto vector_type = type;
3907 vector_type.columns = 1;
3908 template_expr = join("<", type_to_glsl(vector_type), ">");
3909 load_expr = type_to_glsl(type);
3910 load_op = "Load";
3911 }
3912 else
3913 {
3914 // Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
3915 // so row-major is technically column-major ...
3916 load_expr = type_to_glsl(target_type);
3917 }
3918 load_expr += "(";
3919
3920 for (uint32_t c = 0; c < type.columns; c++)
3921 {
3922 load_expr += join(base, ".", load_op, template_expr, "(", chain.dynamic_index,
3923 chain.static_index + c * chain.matrix_stride, ")");
3924 if (c + 1 < type.columns)
3925 load_expr += ", ";
3926 }
3927 load_expr += ")";
3928 }
3929 else
3930 {
3931 // Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
3932 // considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
3933
3934 if (templated_load)
3935 {
3936 load_expr = type_to_glsl(type);
3937 auto scalar_type = type;
3938 scalar_type.vecsize = 1;
3939 scalar_type.columns = 1;
3940 template_expr = join("<", type_to_glsl(scalar_type), ">");
3941 }
3942 else
3943 load_expr = type_to_glsl(target_type);
3944
3945 load_expr += "(";
3946
3947 for (uint32_t c = 0; c < type.columns; c++)
3948 {
3949 for (uint32_t r = 0; r < type.vecsize; r++)
3950 {
3951 load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3952 chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")");
3953
3954 if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
3955 load_expr += ", ";
3956 }
3957 }
3958 load_expr += ")";
3959 }
3960
3961 if (!templated_load)
3962 {
3963 auto bitcast_op = bitcast_glsl_op(type, target_type);
3964 if (!bitcast_op.empty())
3965 load_expr = join(bitcast_op, "(", load_expr, ")");
3966 }
3967
3968 if (lhs.empty())
3969 {
3970 assert(expr);
3971 *expr = move(load_expr);
3972 }
3973 else
3974 statement(lhs, " = ", load_expr, ";");
3975 }
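// End to end, a non-templated float4 load from a ByteAddressBuffer at a
// constant offset comes out roughly as
//
//   asfloat(_buf.Load4(16))
//
// while SM 6.2 templated loads become _buf.Load<float4>(16), no bitcast needed.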
3976
3977 void CompilerHLSL::emit_load(const Instruction &instruction)
3978 {
3979 auto ops = stream(instruction);
3980
3981 auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
3982 if (chain)
3983 {
3984 uint32_t result_type = ops[0];
3985 uint32_t id = ops[1];
3986 uint32_t ptr = ops[2];
3987
3988 auto &type = get<SPIRType>(result_type);
3989 bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
3990
3991 if (composite_load)
3992 {
3993 // We cannot make this work in one single expression as we might have nested structures and arrays,
3994 // so unroll the load to an uninitialized temporary.
3995 emit_uninitialized_temporary_expression(result_type, id);
3996 read_access_chain(nullptr, to_expression(id), *chain);
3997 track_expression_read(chain->self);
3998 }
3999 else
4000 {
4001 string load_expr;
4002 read_access_chain(&load_expr, "", *chain);
4003
4004 bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries);
4005
4006 // If we are forwarding this load,
4007 // don't register the read to access chain here, defer that to when we actually use the expression,
4008 // using the add_implied_read_expression mechanism.
4009 if (!forward)
4010 track_expression_read(chain->self);
4011
4012 // Do not forward complex load sequences like matrices, structs and arrays.
4013 if (type.columns > 1)
4014 forward = false;
4015
4016 auto &e = emit_op(result_type, id, load_expr, forward, true);
4017 e.need_transpose = false;
4018 register_read(id, ptr, forward);
4019 inherit_expression_dependencies(id, ptr);
4020 if (forward)
4021 add_implied_read_expression(e, chain->self);
4022 }
4023 }
4024 else
4025 CompilerGLSL::emit_instruction(instruction);
4026 }
4027
4028 void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4029 const SmallVector<uint32_t> &composite_chain)
4030 {
4031 auto &type = get<SPIRType>(chain.basetype);
4032
4033 // Use a reserved identifier here, since an ordinary name might shadow an identifier in the access chain input or other loops.
4034 auto ident = get_unique_identifier();
4035
4036 uint32_t id = ir.increase_bound_by(2);
4037 uint32_t int_type_id = id + 1;
4038 SPIRType int_type;
4039 int_type.basetype = SPIRType::Int;
4040 int_type.width = 32;
4041 set<SPIRType>(int_type_id, int_type);
4042 set<SPIRExpression>(id, ident, int_type_id, true);
4043 set_name(id, ident);
4044 suppressed_usage_tracking.insert(id);
4045
4046 statement("[unroll]");
4047 statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
4048 ident, "++)");
4049 begin_scope();
4050 auto subchain = chain;
4051 subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
4052 subchain.basetype = type.parent_type;
4053
4054 // Forcefully allow us to use an ID here by setting MSB.
4055 auto subcomposite_chain = composite_chain;
4056 subcomposite_chain.push_back(0x80000000u | id);
4057
4058 if (!get<SPIRType>(subchain.basetype).array.empty())
4059 subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
4060
4061 write_access_chain(subchain, value, subcomposite_chain);
4062 end_scope();
4063 }
4064
4065 void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4066 const SmallVector<uint32_t> &composite_chain)
4067 {
4068 auto &type = get<SPIRType>(chain.basetype);
4069 uint32_t member_count = uint32_t(type.member_types.size());
4070 auto subchain = chain;
4071
4072 auto subcomposite_chain = composite_chain;
4073 subcomposite_chain.push_back(0);
4074
4075 for (uint32_t i = 0; i < member_count; i++)
4076 {
4077 uint32_t offset = type_struct_member_offset(type, i);
4078 subchain.static_index = chain.static_index + offset;
4079 subchain.basetype = type.member_types[i];
4080
4081 subchain.matrix_stride = 0;
4082 subchain.array_stride = 0;
4083 subchain.row_major_matrix = false;
4084
4085 auto &member_type = get<SPIRType>(subchain.basetype);
4086 if (member_type.columns > 1)
4087 {
4088 subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
4089 subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
4090 }
4091
4092 if (!member_type.array.empty())
4093 subchain.array_stride = type_struct_member_array_stride(type, i);
4094
4095 subcomposite_chain.back() = i;
4096 write_access_chain(subchain, value, subcomposite_chain);
4097 }
4098 }
4099
4100 string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4101 bool enclose)
4102 {
4103 string ret;
4104 if (composite_chain.empty())
4105 ret = to_expression(value);
4106 else
4107 {
4108 AccessChainMeta meta;
4109 ret = access_chain_internal(value, composite_chain.data(), uint32_t(composite_chain.size()),
4110 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, &meta);
4111 }
4112
4113 if (enclose)
4114 ret = enclose_expression(ret);
4115 return ret;
4116 }
4117
4118 void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4119 const SmallVector<uint32_t> &composite_chain)
4120 {
4121 auto &type = get<SPIRType>(chain.basetype);
4122
4123 // Make sure we trigger a read of the constituents in the access chain.
4124 track_expression_read(chain.self);
4125
4126 SPIRType target_type;
4127 target_type.basetype = SPIRType::UInt;
4128 target_type.vecsize = type.vecsize;
4129 target_type.columns = type.columns;
4130
4131 if (!type.array.empty())
4132 {
4133 write_access_chain_array(chain, value, composite_chain);
4134 register_write(chain.self);
4135 return;
4136 }
4137 else if (type.basetype == SPIRType::Struct)
4138 {
4139 write_access_chain_struct(chain, value, composite_chain);
4140 register_write(chain.self);
4141 return;
4142 }
4143 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4144 SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4145 "native 16-bit types are enabled.");
4146
4147 bool templated_store = hlsl_options.shader_model >= 62;
4148
4149 auto base = chain.base;
4150 if (has_decoration(chain.self, DecorationNonUniform))
4151 convert_non_uniform_expression(base, chain.self);
4152
4153 string template_expr;
4154 if (templated_store)
4155 template_expr = join("<", type_to_glsl(type), ">");
4156
4157 if (type.columns == 1 && !chain.row_major_matrix)
4158 {
4159 const char *store_op = nullptr;
4160 switch (type.vecsize)
4161 {
4162 case 1:
4163 store_op = "Store";
4164 break;
4165 case 2:
4166 store_op = "Store2";
4167 break;
4168 case 3:
4169 store_op = "Store3";
4170 break;
4171 case 4:
4172 store_op = "Store4";
4173 break;
4174 default:
4175 SPIRV_CROSS_THROW("Unknown vector size.");
4176 }
4177
4178 auto store_expr = write_access_chain_value(value, composite_chain, false);
4179
4180 if (!templated_store)
4181 {
4182 auto bitcast_op = bitcast_glsl_op(target_type, type);
4183 if (!bitcast_op.empty())
4184 store_expr = join(bitcast_op, "(", store_expr, ")");
4185 }
4186 else
4187 store_op = "Store";
4188 statement(base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index, ", ",
4189 store_expr, ");");
4190 }
4191 else if (type.columns == 1)
4192 {
4193 if (templated_store)
4194 {
4195 auto scalar_type = type;
4196 scalar_type.vecsize = 1;
4197 scalar_type.columns = 1;
4198 template_expr = join("<", type_to_glsl(scalar_type), ">");
4199 }
4200
4201 // Strided store.
4202 for (uint32_t r = 0; r < type.vecsize; r++)
4203 {
4204 auto store_expr = write_access_chain_value(value, composite_chain, true);
4205 if (type.vecsize > 1)
4206 {
4207 store_expr += ".";
4208 store_expr += index_to_swizzle(r);
4209 }
4210 remove_duplicate_swizzle(store_expr);
4211
4212 if (!templated_store)
4213 {
4214 auto bitcast_op = bitcast_glsl_op(target_type, type);
4215 if (!bitcast_op.empty())
4216 store_expr = join(bitcast_op, "(", store_expr, ")");
4217 }
4218
4219 statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4220 chain.static_index + chain.matrix_stride * r, ", ", store_expr, ");");
4221 }
4222 }
4223 else if (!chain.row_major_matrix)
4224 {
4225 const char *store_op = nullptr;
4226 switch (type.vecsize)
4227 {
4228 case 1:
4229 store_op = "Store";
4230 break;
4231 case 2:
4232 store_op = "Store2";
4233 break;
4234 case 3:
4235 store_op = "Store3";
4236 break;
4237 case 4:
4238 store_op = "Store4";
4239 break;
4240 default:
4241 SPIRV_CROSS_THROW("Unknown vector size.");
4242 }
4243
4244 if (templated_store)
4245 {
4246 store_op = "Store";
4247 auto vector_type = type;
4248 vector_type.columns = 1;
4249 template_expr = join("<", type_to_glsl(vector_type), ">");
4250 }
4251
4252 for (uint32_t c = 0; c < type.columns; c++)
4253 {
4254 auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]");
4255
4256 if (!templated_store)
4257 {
4258 auto bitcast_op = bitcast_glsl_op(target_type, type);
4259 if (!bitcast_op.empty())
4260 store_expr = join(bitcast_op, "(", store_expr, ")");
4261 }
4262
4263 statement(base, ".", store_op, template_expr, "(", chain.dynamic_index,
4264 chain.static_index + c * chain.matrix_stride, ", ", store_expr, ");");
4265 }
4266 }
4267 else
4268 {
4269 if (templated_store)
4270 {
4271 auto scalar_type = type;
4272 scalar_type.vecsize = 1;
4273 scalar_type.columns = 1;
4274 template_expr = join("<", type_to_glsl(scalar_type), ">");
4275 }
4276
4277 for (uint32_t r = 0; r < type.vecsize; r++)
4278 {
4279 for (uint32_t c = 0; c < type.columns; c++)
4280 {
4281 auto store_expr =
4282 join(write_access_chain_value(value, composite_chain, true), "[", c, "].", index_to_swizzle(r));
4283 remove_duplicate_swizzle(store_expr);
4284 auto bitcast_op = bitcast_glsl_op(target_type, type);
4285 if (!bitcast_op.empty())
4286 store_expr = join(bitcast_op, "(", store_expr, ")");
4287 statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4288 chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");");
4289 }
4290 }
4291 }
4292
4293 register_write(chain.self);
4294 }
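// The store path mirrors the load path above; a float4 store at a constant
// offset is emitted roughly as
//
//   _buf.Store4(16, asuint(value));
//
// or, with SM 6.2 templated stores, _buf.Store<float4>(16, value);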
4295
4296 void CompilerHLSL::emit_store(const Instruction &instruction)
4297 {
4298 auto ops = stream(instruction);
4299 auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4300 if (chain)
4301 write_access_chain(*chain, ops[1], {});
4302 else
4303 CompilerGLSL::emit_instruction(instruction);
4304 }
4305
4306 void CompilerHLSL::emit_access_chain(const Instruction &instruction)
4307 {
4308 auto ops = stream(instruction);
4309 uint32_t length = instruction.length;
4310
4311 bool need_byte_access_chain = false;
4312 auto &type = expression_type(ops[2]);
4313 const auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4314
4315 if (chain)
4316 {
4317 // Keep tacking on an existing access chain.
4318 need_byte_access_chain = true;
4319 }
4320 else if (type.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock))
4321 {
4322 // If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
4323 // to emit SPIRAccessChain rather than a plain SPIRExpression.
4324 uint32_t chain_arguments = length - 3;
4325 if (chain_arguments > type.array.size())
4326 need_byte_access_chain = true;
4327 }
4328
4329 if (need_byte_access_chain)
4330 {
4331 // If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within a block,
4332 // and not an array of SSBOs.
4333 uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
4334
4335 auto *backing_variable = maybe_get_backing_variable(ops[2]);
4336
4337 string base;
4338 if (to_plain_buffer_length != 0)
4339 base = access_chain(ops[2], &ops[3], to_plain_buffer_length, get<SPIRType>(ops[0]));
4340 else if (chain)
4341 base = chain->base;
4342 else
4343 base = to_expression(ops[2]);
4344
4345 // Start traversing type hierarchy at the proper non-pointer types.
4346 auto *basetype = &get_pointee_type(type);
4347
4348 // Traverse the type hierarchy down to the actual buffer types.
4349 for (uint32_t i = 0; i < to_plain_buffer_length; i++)
4350 {
4351 assert(basetype->parent_type);
4352 basetype = &get<SPIRType>(basetype->parent_type);
4353 }
4354
4355 uint32_t matrix_stride = 0;
4356 uint32_t array_stride = 0;
4357 bool row_major_matrix = false;
4358
4359 // Inherit matrix information.
4360 if (chain)
4361 {
4362 matrix_stride = chain->matrix_stride;
4363 row_major_matrix = chain->row_major_matrix;
4364 array_stride = chain->array_stride;
4365 }
4366
4367 auto offsets = flattened_access_chain_offset(*basetype, &ops[3 + to_plain_buffer_length],
4368 length - 3 - to_plain_buffer_length, 0, 1, &row_major_matrix,
4369 &matrix_stride, &array_stride);
4370
4371 auto &e = set<SPIRAccessChain>(ops[1], ops[0], type.storage, base, offsets.first, offsets.second);
4372 e.row_major_matrix = row_major_matrix;
4373 e.matrix_stride = matrix_stride;
4374 e.array_stride = array_stride;
4375 e.immutable = should_forward(ops[2]);
4376 e.loaded_from = backing_variable ? backing_variable->self : ID(0);
4377
4378 if (chain)
4379 {
4380 e.dynamic_index += chain->dynamic_index;
4381 e.static_index += chain->static_index;
4382 }
4383
4384 for (uint32_t i = 2; i < length; i++)
4385 {
4386 inherit_expression_dependencies(ops[1], ops[i]);
4387 add_implied_read_expression(e, ops[i]);
4388 }
4389 }
4390 else
4391 {
4392 CompilerGLSL::emit_instruction(instruction);
4393 }
4394 }
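// To make the SPIRAccessChain fields concrete: for an access like
// ssbo.verts[i].pos, base holds the buffer expression ("ssbo"), dynamic_index
// the runtime byte offset ("i * stride + ") and static_index the compile-time
// byte offset of pos within the element (names illustrative).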
4395
4396 void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
4397 {
4398 const char *atomic_op = nullptr;
4399
4400 string value_expr;
4401 if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
4402 value_expr = to_expression(ops[op == OpAtomicCompareExchange ? 6 : 5]);
4403
4404 bool is_atomic_store = false;
4405
4406 switch (op)
4407 {
4408 case OpAtomicIIncrement:
4409 atomic_op = "InterlockedAdd";
4410 value_expr = "1";
4411 break;
4412
4413 case OpAtomicIDecrement:
4414 atomic_op = "InterlockedAdd";
4415 value_expr = "-1";
4416 break;
4417
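// There is no plain interlocked load in HLSL, so emulate it below as an
// InterlockedAdd of 0, which returns the original value.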
4418 case OpAtomicLoad:
4419 atomic_op = "InterlockedAdd";
4420 value_expr = "0";
4421 break;
4422
4423 case OpAtomicISub:
4424 atomic_op = "InterlockedAdd";
4425 value_expr = join("-", enclose_expression(value_expr));
4426 break;
4427
4428 case OpAtomicSMin:
4429 case OpAtomicUMin:
4430 atomic_op = "InterlockedMin";
4431 break;
4432
4433 case OpAtomicSMax:
4434 case OpAtomicUMax:
4435 atomic_op = "InterlockedMax";
4436 break;
4437
4438 case OpAtomicAnd:
4439 atomic_op = "InterlockedAnd";
4440 break;
4441
4442 case OpAtomicOr:
4443 atomic_op = "InterlockedOr";
4444 break;
4445
4446 case OpAtomicXor:
4447 atomic_op = "InterlockedXor";
4448 break;
4449
4450 case OpAtomicIAdd:
4451 atomic_op = "InterlockedAdd";
4452 break;
4453
4454 case OpAtomicExchange:
4455 atomic_op = "InterlockedExchange";
4456 break;
4457
4458 case OpAtomicStore:
4459 atomic_op = "InterlockedExchange";
4460 is_atomic_store = true;
4461 break;
4462
4463 case OpAtomicCompareExchange:
4464 if (length < 8)
4465 SPIRV_CROSS_THROW("Not enough data for opcode.");
4466 atomic_op = "InterlockedCompareExchange";
4467 value_expr = join(to_expression(ops[7]), ", ", value_expr);
4468 break;
4469
4470 default:
4471 SPIRV_CROSS_THROW("Unknown atomic opcode.");
4472 }
4473
4474 if (is_atomic_store)
4475 {
4476 auto &data_type = expression_type(ops[0]);
4477 auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4478
4479 auto &tmp_id = extra_sub_expressions[ops[0]];
4480 if (!tmp_id)
4481 {
4482 tmp_id = ir.increase_bound_by(1);
4483 emit_uninitialized_temporary_expression(get_pointee_type(data_type).self, tmp_id);
4484 }
4485
4486 if (data_type.storage == StorageClassImage || !chain)
4487 {
4488 statement(atomic_op, "(", to_non_uniform_aware_expression(ops[0]), ", ",
4489 to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4490 }
4491 else
4492 {
4493 string base = chain->base;
4494 if (has_decoration(chain->self, DecorationNonUniform))
4495 convert_non_uniform_expression(base, chain->self);
4496 // RWByteAddressBuffer is always uint in its underlying type.
4497 statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ",
4498 to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4499 }
4500 }
4501 else
4502 {
4503 uint32_t result_type = ops[0];
4504 uint32_t id = ops[1];
4505 forced_temporaries.insert(ops[1]);
4506
4507 auto &type = get<SPIRType>(result_type);
4508 statement(variable_decl(type, to_name(id)), ";");
4509
4510 auto &data_type = expression_type(ops[2]);
4511 auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4512 SPIRType::BaseType expr_type;
4513 if (data_type.storage == StorageClassImage || !chain)
4514 {
4515 statement(atomic_op, "(", to_non_uniform_aware_expression(ops[2]), ", ", value_expr, ", ", to_name(id), ");");
4516 expr_type = data_type.basetype;
4517 }
4518 else
4519 {
4520 // RWByteAddressBuffer is always uint in its underlying type.
4521 string base = chain->base;
4522 if (has_decoration(chain->self, DecorationNonUniform))
4523 convert_non_uniform_expression(base, chain->self);
4524 expr_type = SPIRType::UInt;
4525 statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", value_expr,
4526 ", ", to_name(id), ");");
4527 }
4528
4529 auto expr = bitcast_expression(type, expr_type, to_name(id));
4530 set<SPIRExpression>(id, expr, result_type, true);
4531 }
4532 flush_all_atomic_capable_variables();
4533 }
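// For example, an OpAtomicIAdd on an SSBO member lowers to something like
// (names illustrative):
//
//   uint _47;
//   _ssbo.InterlockedAdd(16, _value, _47);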
4534
4535 void CompilerHLSL::emit_subgroup_op(const Instruction &i)
4536 {
4537 if (hlsl_options.shader_model < 60)
4538 SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
4539
4540 const uint32_t *ops = stream(i);
4541 auto op = static_cast<Op>(i.op);
4542
4543 uint32_t result_type = ops[0];
4544 uint32_t id = ops[1];
4545
4546 auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
4547 if (scope != ScopeSubgroup)
4548 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
4549
4550 const auto make_inclusive_Sum = [&](const string &expr) -> string {
4551 return join(expr, " + ", to_expression(ops[4]));
4552 };
4553
4554 const auto make_inclusive_Product = [&](const string &expr) -> string {
4555 return join(expr, " * ", to_expression(ops[4]));
4556 };
4557
4558 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4559 uint32_t integer_width = get_integer_width_for_instruction(i);
4560 auto int_type = to_signed_basetype(integer_width);
4561 auto uint_type = to_unsigned_basetype(integer_width);
4562
4563 #define make_inclusive_BitAnd(expr) ""
4564 #define make_inclusive_BitOr(expr) ""
4565 #define make_inclusive_BitXor(expr) ""
4566 #define make_inclusive_Min(expr) ""
4567 #define make_inclusive_Max(expr) ""
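// The bitwise/min/max ops above never take the inclusive-scan path
// (supports_scan is false for them in the table below), so these macros only
// need to exist for the HLSL_GROUP_OP expansion to compile.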
4568
4569 switch (op)
4570 {
4571 case OpGroupNonUniformElect:
4572 emit_op(result_type, id, "WaveIsFirstLane()", true);
4573 break;
4574
4575 case OpGroupNonUniformBroadcast:
4576 emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4577 break;
4578
4579 case OpGroupNonUniformBroadcastFirst:
4580 emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
4581 break;
4582
4583 case OpGroupNonUniformBallot:
4584 emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
4585 break;
4586
4587 case OpGroupNonUniformInverseBallot:
4588 SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
4589
4590 case OpGroupNonUniformBallotBitExtract:
4591 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
4592
4593 case OpGroupNonUniformBallotFindLSB:
4594 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
4595
4596 case OpGroupNonUniformBallotFindMSB:
4597 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
4598
4599 case OpGroupNonUniformBallotBitCount:
4600 {
4601 auto operation = static_cast<GroupOperation>(ops[3]);
4602 if (operation == GroupOperationReduce)
4603 {
4604 bool forward = should_forward(ops[4]);
4605 auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
4606 to_enclosed_expression(ops[4]), ".y)");
4607 auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
4608 to_enclosed_expression(ops[4]), ".w)");
4609 emit_op(result_type, id, join(left, " + ", right), forward);
4610 inherit_expression_dependencies(id, ops[4]);
4611 }
4612 else if (operation == GroupOperationInclusiveScan)
4613 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
4614 else if (operation == GroupOperationExclusiveScan)
4615 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
4616 else
4617 SPIRV_CROSS_THROW("Invalid BitCount operation.");
4618 break;
4619 }
4620
4621 case OpGroupNonUniformShuffle:
4622 emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4623 break;
4624 case OpGroupNonUniformShuffleXor:
4625 {
4626 bool forward = should_forward(ops[3]);
4627 emit_op(ops[0], ops[1],
4628 join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4629 "WaveGetLaneIndex() ^ ", to_enclosed_expression(ops[4]), ")"), forward);
4630 inherit_expression_dependencies(ops[1], ops[3]);
4631 break;
4632 }
4633 case OpGroupNonUniformShuffleUp:
4634 {
4635 bool forward = should_forward(ops[3]);
4636 emit_op(ops[0], ops[1],
4637 join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4638 "WaveGetLaneIndex() - ", to_enclosed_expression(ops[4]), ")"), forward);
4639 inherit_expression_dependencies(ops[1], ops[3]);
4640 break;
4641 }
4642 case OpGroupNonUniformShuffleDown:
4643 {
4644 bool forward = should_forward(ops[3]);
4645 emit_op(ops[0], ops[1],
4646 join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4647 "WaveGetLaneIndex() + ", to_enclosed_expression(ops[4]), ")"), forward);
4648 inherit_expression_dependencies(ops[1], ops[3]);
4649 break;
4650 }
4651
4652 case OpGroupNonUniformAll:
4653 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
4654 break;
4655
4656 case OpGroupNonUniformAny:
4657 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
4658 break;
4659
4660 case OpGroupNonUniformAllEqual:
4661 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllEqual");
4662 break;
4663
4664 // clang-format off
4665 #define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
4666 case OpGroupNonUniform##op: \
4667 { \
4668 auto operation = static_cast<GroupOperation>(ops[3]); \
4669 if (operation == GroupOperationReduce) \
4670 emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
4671 else if (operation == GroupOperationInclusiveScan && supports_scan) \
4672 { \
4673 bool forward = should_forward(ops[4]); \
4674 emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
4675 inherit_expression_dependencies(id, ops[4]); \
4676 } \
4677 else if (operation == GroupOperationExclusiveScan && supports_scan) \
4678 emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
4679 else if (operation == GroupOperationClusteredReduce) \
4680 SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
4681 else \
4682 SPIRV_CROSS_THROW("Invalid group operation."); \
4683 break; \
4684 }
4685
4686 #define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
4687 case OpGroupNonUniform##op: \
4688 { \
4689 auto operation = static_cast<GroupOperation>(ops[3]); \
4690 if (operation == GroupOperationReduce) \
4691 emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
4692 else \
4693 SPIRV_CROSS_THROW("Invalid group operation."); \
4694 break; \
4695 }
4696
4697 HLSL_GROUP_OP(FAdd, Sum, true)
4698 HLSL_GROUP_OP(FMul, Product, true)
4699 HLSL_GROUP_OP(FMin, Min, false)
4700 HLSL_GROUP_OP(FMax, Max, false)
4701 HLSL_GROUP_OP(IAdd, Sum, true)
4702 HLSL_GROUP_OP(IMul, Product, true)
4703 HLSL_GROUP_OP_CAST(SMin, Min, int_type)
4704 HLSL_GROUP_OP_CAST(SMax, Max, int_type)
4705 HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
4706 HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
4707 HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
4708 HLSL_GROUP_OP(BitwiseOr, BitOr, false)
4709 HLSL_GROUP_OP(BitwiseXor, BitXor, false)
4710 HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
4711 HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
4712 HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
4713
4714 #undef HLSL_GROUP_OP
4715 #undef HLSL_GROUP_OP_CAST
4716 // clang-format on
4717
4718 case OpGroupNonUniformQuadSwap:
4719 {
4720 uint32_t direction = evaluate_constant_u32(ops[4]);
4721 if (direction == 0)
4722 emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
4723 else if (direction == 1)
4724 emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
4725 else if (direction == 2)
4726 emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
4727 else
4728 SPIRV_CROSS_THROW("Invalid quad swap direction.");
4729 break;
4730 }
4731
4732 case OpGroupNonUniformQuadBroadcast:
4733 {
4734 emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
4735 break;
4736 }
4737
4738 default:
4739 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
4740 }
4741
4742 register_control_dependent_expression(id);
4743 }
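// As a concrete sample of the mappings above, OpGroupNonUniformBallotBitCount
// with GroupOperationReduce over a ballot b is emitted as
//
//   countbits(b.x) + countbits(b.y) + countbits(b.z) + countbits(b.w)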
4744
4745 void CompilerHLSL::emit_instruction(const Instruction &instruction)
4746 {
4747 auto ops = stream(instruction);
4748 auto opcode = static_cast<Op>(instruction.op);
4749
4750 #define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
4751 #define HLSL_BOP_CAST(op, type) \
4752 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4753 #define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
4754 #define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
4755 #define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
4756 #define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4757 #define HLSL_BFOP_CAST(op, type) \
4758 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4760 #define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
4761
4762 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4763 uint32_t integer_width = get_integer_width_for_instruction(instruction);
4764 auto int_type = to_signed_basetype(integer_width);
4765 auto uint_type = to_unsigned_basetype(integer_width);
4766
4767 switch (opcode)
4768 {
4769 case OpAccessChain:
4770 case OpInBoundsAccessChain:
4771 {
4772 emit_access_chain(instruction);
4773 break;
4774 }
4775 case OpBitcast:
4776 {
4777 auto bitcast_type = get_bitcast_type(ops[0], ops[2]);
4778 if (bitcast_type == CompilerHLSL::TypeNormal)
4779 CompilerGLSL::emit_instruction(instruction);
4780 else
4781 {
4782 if (!requires_uint2_packing)
4783 {
4784 requires_uint2_packing = true;
4785 force_recompile();
4786 }
4787
4788 if (bitcast_type == CompilerHLSL::TypePackUint2x32)
4789 emit_unary_func_op(ops[0], ops[1], ops[2], "spvPackUint2x32");
4790 else
4791 emit_unary_func_op(ops[0], ops[1], ops[2], "spvUnpackUint2x32");
4792 }
4793
4794 break;
4795 }
4796
4797 case OpSelect:
4798 {
4799 auto &value_type = expression_type(ops[3]);
4800 if (value_type.basetype == SPIRType::Struct || is_array(value_type))
4801 {
4802 // HLSL does not support ternary expressions on composites.
4803 // Cannot use branches, since we might be in a continue block
4804 // where explicit control flow is prohibited.
4805 // Emit a helper function where we can use control flow.
4806 TypeID value_type_id = expression_type_id(ops[3]);
4807 auto itr = std::find(composite_selection_workaround_types.begin(),
4808 composite_selection_workaround_types.end(),
4809 value_type_id);
4810 if (itr == composite_selection_workaround_types.end())
4811 {
4812 composite_selection_workaround_types.push_back(value_type_id);
4813 force_recompile();
4814 }
4815 emit_uninitialized_temporary_expression(ops[0], ops[1]);
4816 statement("spvSelectComposite(",
4817 to_expression(ops[1]), ", ", to_expression(ops[2]), ", ",
4818 to_expression(ops[3]), ", ", to_expression(ops[4]), ");");
4819 }
4820 else
4821 CompilerGLSL::emit_instruction(instruction);
4822 break;
4823 }
4824
4825 case OpStore:
4826 {
4827 emit_store(instruction);
4828 break;
4829 }
4830
4831 case OpLoad:
4832 {
4833 emit_load(instruction);
4834 break;
4835 }
4836
4837 case OpMatrixTimesVector:
4838 {
4839 // Matrices are kept in a transposed state all the time, so always flip the multiplication order.
4840 emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4841 break;
4842 }
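// Since matrices are stored transposed, SPIR-V "M * v" is emitted as
// mul(v, M); with HLSL's row-vector convention this computes the same result.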
4843
4844 case OpVectorTimesMatrix:
4845 {
4846 // Matrices are kept in a transposed state all the time, so always flip the multiplication order.
4847 emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4848 break;
4849 }
4850
4851 case OpMatrixTimesMatrix:
4852 {
4853 // Matrices are kept in a transposed state all the time, so always flip the multiplication order.
4854 emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4855 break;
4856 }
4857
4858 case OpOuterProduct:
4859 {
4860 uint32_t result_type = ops[0];
4861 uint32_t id = ops[1];
4862 uint32_t a = ops[2];
4863 uint32_t b = ops[3];
4864
4865 auto &type = get<SPIRType>(result_type);
4866 string expr = type_to_glsl_constructor(type);
4867 expr += "(";
4868 for (uint32_t col = 0; col < type.columns; col++)
4869 {
4870 expr += to_enclosed_expression(a);
4871 expr += " * ";
4872 expr += to_extract_component_expression(b, col);
4873 if (col + 1 < type.columns)
4874 expr += ", ";
4875 }
4876 expr += ")";
4877 emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
4878 inherit_expression_dependencies(id, a);
4879 inherit_expression_dependencies(id, b);
4880 break;
4881 }
4882
4883 case OpFMod:
4884 {
4885 if (!requires_op_fmod)
4886 {
4887 requires_op_fmod = true;
4888 force_recompile();
4889 }
4890 CompilerGLSL::emit_instruction(instruction);
4891 break;
4892 }
4893
4894 case OpFRem:
4895 emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], "fmod");
4896 break;
4897
4898 case OpImage:
4899 {
4900 uint32_t result_type = ops[0];
4901 uint32_t id = ops[1];
4902 auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
4903
4904 if (combined)
4905 {
4906 auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
4907 auto *var = maybe_get_backing_variable(combined->image);
4908 if (var)
4909 e.loaded_from = var->self;
4910 }
4911 else
4912 {
4913 auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true);
4914 auto *var = maybe_get_backing_variable(ops[2]);
4915 if (var)
4916 e.loaded_from = var->self;
4917 }
4918 break;
4919 }
4920
4921 case OpDPdx:
4922 HLSL_UFOP(ddx);
4923 register_control_dependent_expression(ops[1]);
4924 break;
4925
4926 case OpDPdy:
4927 HLSL_UFOP(ddy);
4928 register_control_dependent_expression(ops[1]);
4929 break;
4930
4931 case OpDPdxFine:
4932 HLSL_UFOP(ddx_fine);
4933 register_control_dependent_expression(ops[1]);
4934 break;
4935
4936 case OpDPdyFine:
4937 HLSL_UFOP(ddy_fine);
4938 register_control_dependent_expression(ops[1]);
4939 break;
4940
4941 case OpDPdxCoarse:
4942 HLSL_UFOP(ddx_coarse);
4943 register_control_dependent_expression(ops[1]);
4944 break;
4945
4946 case OpDPdyCoarse:
4947 HLSL_UFOP(ddy_coarse);
4948 register_control_dependent_expression(ops[1]);
4949 break;
4950
4951 case OpFwidth:
4952 case OpFwidthCoarse:
4953 case OpFwidthFine:
4954 HLSL_UFOP(fwidth);
4955 register_control_dependent_expression(ops[1]);
4956 break;
4957
4958 case OpLogicalNot:
4959 {
4960 auto result_type = ops[0];
4961 auto id = ops[1];
4962 auto &type = get<SPIRType>(result_type);
4963
4964 if (type.vecsize > 1)
4965 emit_unrolled_unary_op(result_type, id, ops[2], "!");
4966 else
4967 HLSL_UOP(!);
4968 break;
4969 }
4970
4971 case OpIEqual:
4972 {
4973 auto result_type = ops[0];
4974 auto id = ops[1];
4975
4976 if (expression_type(ops[2]).vecsize > 1)
4977 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4978 else
4979 HLSL_BOP_CAST(==, int_type);
4980 break;
4981 }
4982
4983 case OpLogicalEqual:
4984 case OpFOrdEqual:
4985 case OpFUnordEqual:
4986 {
4987 // HLSL != operator is unordered.
4988 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
4989 // isnan() is apparently implemented as x != x as well.
4990 // We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
4991 // HACK: FUnordEqual will be implemented as FOrdEqual.
4992
4993 auto result_type = ops[0];
4994 auto id = ops[1];
4995
4996 if (expression_type(ops[2]).vecsize > 1)
4997 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4998 else
4999 HLSL_BOP(==);
5000 break;
5001 }
5002
5003 case OpINotEqual:
5004 {
5005 auto result_type = ops[0];
5006 auto id = ops[1];
5007
5008 if (expression_type(ops[2]).vecsize > 1)
5009 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
5010 else
5011 HLSL_BOP_CAST(!=, int_type);
5012 break;
5013 }
5014
5015 case OpLogicalNotEqual:
5016 case OpFOrdNotEqual:
5017 case OpFUnordNotEqual:
5018 {
5019 // HLSL != operator is unordered.
5020 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5021 // isnan() is apparently implemented as x != x as well.
5022
5023 // FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
5024 // We would need to do something like not(UnordEqual), but that cannot be expressed either.
5025 // Adding a lot of NaN checks would be a breaking change from perspective of performance.
5026 // SPIR-V will generally use isnan() checks when this even matters.
5027 // HACK: FOrdNotEqual will be implemented as FUnordEqual.
5028
5029 auto result_type = ops[0];
5030 auto id = ops[1];
5031
5032 if (expression_type(ops[2]).vecsize > 1)
5033 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
5034 else
5035 HLSL_BOP(!=);
5036 break;
5037 }
5038
5039 case OpUGreaterThan:
5040 case OpSGreaterThan:
5041 {
5042 auto result_type = ops[0];
5043 auto id = ops[1];
5044 auto type = opcode == OpUGreaterThan ? uint_type : int_type;
5045
5046 if (expression_type(ops[2]).vecsize > 1)
5047 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, type);
5048 else
5049 HLSL_BOP_CAST(>, type);
5050 break;
5051 }
5052
5053 case OpFOrdGreaterThan:
5054 {
5055 auto result_type = ops[0];
5056 auto id = ops[1];
5057
5058 if (expression_type(ops[2]).vecsize > 1)
5059 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, SPIRType::Unknown);
5060 else
5061 HLSL_BOP(>);
5062 break;
5063 }
5064
5065 case OpFUnordGreaterThan:
5066 {
5067 auto result_type = ops[0];
5068 auto id = ops[1];
5069
5070 if (expression_type(ops[2]).vecsize > 1)
5071 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", true, SPIRType::Unknown);
5072 else
5073 CompilerGLSL::emit_instruction(instruction);
5074 break;
5075 }
5076
5077 case OpUGreaterThanEqual:
5078 case OpSGreaterThanEqual:
5079 {
5080 auto result_type = ops[0];
5081 auto id = ops[1];
5082
5083 auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5084 if (expression_type(ops[2]).vecsize > 1)
5085 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, type);
5086 else
5087 HLSL_BOP_CAST(>=, type);
5088 break;
5089 }
5090
5091 case OpFOrdGreaterThanEqual:
5092 {
5093 auto result_type = ops[0];
5094 auto id = ops[1];
5095
5096 if (expression_type(ops[2]).vecsize > 1)
5097 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, SPIRType::Unknown);
5098 else
5099 HLSL_BOP(>=);
5100 break;
5101 }
5102
5103 case OpFUnordGreaterThanEqual:
5104 {
5105 auto result_type = ops[0];
5106 auto id = ops[1];
5107
5108 if (expression_type(ops[2]).vecsize > 1)
5109 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", true, SPIRType::Unknown);
5110 else
5111 CompilerGLSL::emit_instruction(instruction);
5112 break;
5113 }
5114
5115 case OpULessThan:
5116 case OpSLessThan:
5117 {
5118 auto result_type = ops[0];
5119 auto id = ops[1];
5120
5121 auto type = opcode == OpULessThan ? uint_type : int_type;
5122 if (expression_type(ops[2]).vecsize > 1)
5123 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, type);
5124 else
5125 HLSL_BOP_CAST(<, type);
5126 break;
5127 }
5128
5129 case OpFOrdLessThan:
5130 {
5131 auto result_type = ops[0];
5132 auto id = ops[1];
5133
5134 if (expression_type(ops[2]).vecsize > 1)
5135 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, SPIRType::Unknown);
5136 else
5137 HLSL_BOP(<);
5138 break;
5139 }
5140
5141 case OpFUnordLessThan:
5142 {
5143 auto result_type = ops[0];
5144 auto id = ops[1];
5145
5146 if (expression_type(ops[2]).vecsize > 1)
5147 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", true, SPIRType::Unknown);
5148 else
5149 CompilerGLSL::emit_instruction(instruction);
5150 break;
5151 }
5152
5153 case OpULessThanEqual:
5154 case OpSLessThanEqual:
5155 {
5156 auto result_type = ops[0];
5157 auto id = ops[1];
5158
5159 auto type = opcode == OpULessThanEqual ? uint_type : int_type;
5160 if (expression_type(ops[2]).vecsize > 1)
5161 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, type);
5162 else
5163 HLSL_BOP_CAST(<=, type);
5164 break;
5165 }
5166
5167 case OpFOrdLessThanEqual:
5168 {
5169 auto result_type = ops[0];
5170 auto id = ops[1];
5171
5172 if (expression_type(ops[2]).vecsize > 1)
5173 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, SPIRType::Unknown);
5174 else
5175 HLSL_BOP(<=);
5176 break;
5177 }
5178
5179 case OpFUnordLessThanEqual:
5180 {
5181 auto result_type = ops[0];
5182 auto id = ops[1];
5183
5184 if (expression_type(ops[2]).vecsize > 1)
5185 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", true, SPIRType::Unknown);
5186 else
5187 CompilerGLSL::emit_instruction(instruction);
5188 break;
5189 }
5190
5191 case OpImageQueryLod:
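// Reuses the texture-op path; in the HLSL backend this is expected to come out
// as the CalculateLevelOfDetail() intrinsic family.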
5192 emit_texture_op(instruction, false);
5193 break;
5194
5195 case OpImageQuerySizeLod:
5196 {
5197 auto result_type = ops[0];
5198 auto id = ops[1];
5199
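// HLSL's GetDimensions() returns its results through out-parameters, so size
// queries go through a generated spvTextureSize() helper. The trailing
// levels/samples output is unused here and captured in a dummy variable.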
5200 require_texture_query_variant(ops[2]);
5201 auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5202 statement("uint ", dummy_samples_levels, ";");
5203
5204 auto expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", ",
5205 bitcast_expression(SPIRType::UInt, ops[3]), ", ", dummy_samples_levels, ")");
5206
5207 auto &restype = get<SPIRType>(ops[0]);
5208 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5209 emit_op(result_type, id, expr, true);
5210 break;
5211 }
5212
5213 case OpImageQuerySize:
5214 {
5215 auto result_type = ops[0];
5216 auto id = ops[1];
5217
5218 require_texture_query_variant(ops[2]);
5219 bool uav = expression_type(ops[2]).image.sampled == 2;
5220
5221 if (const auto *var = maybe_get_backing_variable(ops[2]))
5222 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5223 uav = false;
5224
5225 auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5226 statement("uint ", dummy_samples_levels, ";");
5227
5228 string expr;
5229 if (uav)
5230 expr = join("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", dummy_samples_levels, ")");
5231 else
5232 expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", dummy_samples_levels, ")");
5233
5234 auto &restype = get<SPIRType>(ops[0]);
5235 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5236 emit_op(result_type, id, expr, true);
5237 break;
5238 }
5239
5240 case OpImageQuerySamples:
5241 case OpImageQueryLevels:
5242 {
5243 auto result_type = ops[0];
5244 auto id = ops[1];
5245
5246 require_texture_query_variant(ops[2]);
5247 bool uav = expression_type(ops[2]).image.sampled == 2;
5248 if (opcode == OpImageQueryLevels && uav)
5249 SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
5250
5251 if (const auto *var = maybe_get_backing_variable(ops[2]))
5252 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5253 uav = false;
5254
5255 // Keep it simple and do not emit special variants to make this look nicer ...
5256 // This stuff is barely, if ever, used.
5257 forced_temporaries.insert(id);
5258 auto &type = get<SPIRType>(result_type);
5259 statement(variable_decl(type, to_name(id)), ";");
5260
5261 if (uav)
5262 statement("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", to_name(id), ");");
5263 else
5264 statement("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", to_name(id), ");");
5265
5266 auto &restype = get<SPIRType>(ops[0]);
5267 auto expr = bitcast_expression(restype, SPIRType::UInt, to_name(id));
5268 set<SPIRExpression>(id, expr, result_type, true);
5269 break;
5270 }
5271
5272 case OpImageRead:
5273 {
5274 uint32_t result_type = ops[0];
5275 uint32_t id = ops[1];
5276 auto *var = maybe_get_backing_variable(ops[2]);
5277 auto &type = expression_type(ops[2]);
5278 bool subpass_data = type.image.dim == DimSubpassData;
5279 bool pure = false;
5280
5281 string imgexpr;
5282
5283 if (subpass_data)
5284 {
5285 if (hlsl_options.shader_model < 40)
5286 SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
5287
5288 			// Similar to GLSL's texelFetch-based lowering, implement subpass loads with Load() at the current fragment coordinate.
5289 if (type.image.ms)
5290 {
5291 uint32_t operands = ops[4];
5292 if (operands != ImageOperandsSampleMask || instruction.length != 6)
5293 SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
5294 uint32_t sample = ops[5];
5295 imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int2(gl_FragCoord.xy), ", to_expression(sample), ")");
5296 }
5297 else
5298 imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int3(int2(gl_FragCoord.xy), 0))");
5299
5300 pure = true;
5301 }
5302 else
5303 {
5304 imgexpr = join(to_non_uniform_aware_expression(ops[2]), "[", to_expression(ops[3]), "]");
5305 		// Unlike GLSL, where image loads always yield a 4-component vector, the element type of an
5306 		// HLSL RWTexture depends on the image format, so the result may need widening via remap_swizzle.
5307
5308 bool force_srv =
5309 hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(var->self, DecorationNonWritable);
5310 pure = force_srv;
5311
5312 if (var && !subpass_data && !force_srv)
5313 imgexpr = remap_swizzle(get<SPIRType>(result_type),
5314 image_format_to_components(get<SPIRType>(var->basetype).image.format), imgexpr);
5315 }
5316
5317 if (var && var->forwardable)
5318 {
5319 bool forward = forced_temporaries.find(id) == end(forced_temporaries);
5320 auto &e = emit_op(result_type, id, imgexpr, forward);
5321
5322 if (!pure)
5323 {
5324 e.loaded_from = var->self;
5325 if (forward)
5326 var->dependees.push_back(id);
5327 }
5328 }
5329 else
5330 emit_op(result_type, id, imgexpr, false);
5331
5332 inherit_expression_dependencies(id, ops[2]);
5333 if (type.image.ms)
5334 inherit_expression_dependencies(id, ops[5]);
5335 break;
5336 }
5337
5338 case OpImageWrite:
5339 {
5340 auto *var = maybe_get_backing_variable(ops[0]);
5341
5342 	// As with OpImageRead: the element type of the HLSL RWTexture depends on the image format,
5343 	// so the stored value may need narrowing to the texture's component count.
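// E.g. storing a float4 to an R32f image (declared as RWTexture2D<float>)
// narrows the right-hand side to "img[coord] = value.x;".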
5344 auto value_expr = to_expression(ops[2]);
5345 if (var)
5346 {
5347 auto &type = get<SPIRType>(var->basetype);
5348 auto narrowed_type = get<SPIRType>(type.image.type);
5349 narrowed_type.vecsize = image_format_to_components(type.image.format);
5350 value_expr = remap_swizzle(narrowed_type, expression_type(ops[2]).vecsize, value_expr);
5351 }
5352
5353 statement(to_non_uniform_aware_expression(ops[0]), "[", to_expression(ops[1]), "] = ", value_expr, ";");
5354 if (var && variable_storage_is_aliased(*var))
5355 flush_all_aliased_variables();
5356 break;
5357 }
5358
5359 case OpImageTexelPointer:
5360 {
5361 uint32_t result_type = ops[0];
5362 uint32_t id = ops[1];
5363
5364 auto expr = to_expression(ops[2]);
5365 expr += join("[", to_expression(ops[3]), "]");
5366 auto &e = set<SPIRExpression>(id, expr, result_type, true);
5367
5368 // When using the pointer, we need to know which variable it is actually loaded from.
5369 auto *var = maybe_get_backing_variable(ops[2]);
5370 e.loaded_from = var ? var->self : ID(0);
5371 inherit_expression_dependencies(id, ops[3]);
5372 break;
5373 }
5374
5375 case OpAtomicCompareExchange:
5376 case OpAtomicExchange:
5377 case OpAtomicISub:
5378 case OpAtomicSMin:
5379 case OpAtomicUMin:
5380 case OpAtomicSMax:
5381 case OpAtomicUMax:
5382 case OpAtomicAnd:
5383 case OpAtomicOr:
5384 case OpAtomicXor:
5385 case OpAtomicIAdd:
5386 case OpAtomicIIncrement:
5387 case OpAtomicIDecrement:
5388 case OpAtomicLoad:
5389 case OpAtomicStore:
5390 {
5391 emit_atomic(ops, instruction.length, opcode);
5392 break;
5393 }
5394
5395 case OpControlBarrier:
5396 case OpMemoryBarrier:
5397 {
5398 uint32_t memory;
5399 uint32_t semantics;
5400
5401 if (opcode == OpMemoryBarrier)
5402 {
5403 memory = evaluate_constant_u32(ops[0]);
5404 semantics = evaluate_constant_u32(ops[1]);
5405 }
5406 else
5407 {
5408 memory = evaluate_constant_u32(ops[1]);
5409 semantics = evaluate_constant_u32(ops[2]);
5410 }
5411
5412 if (memory == ScopeSubgroup)
5413 {
5414 // No Wave-barriers in HLSL.
5415 break;
5416 }
5417
5418 	// We only care about these flags; acquire/release and friends are not relevant to HLSL.
5419 semantics = mask_relevant_memory_semantics(semantics);
5420
5421 if (opcode == OpMemoryBarrier)
5422 {
5423 		// If we are a memory barrier and the next instruction is a control barrier, check whether
5424 		// that control barrier already covers what we need, so we can elide this redundant barrier.
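// E.g. the common GLSL sequence memoryBarrierShared(); barrier(); then
// collapses into a single GroupMemoryBarrierWithGroupSync().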
5425 const Instruction *next = get_next_instruction_in_block(instruction);
5426 if (next && next->op == OpControlBarrier)
5427 {
5428 auto *next_ops = stream(*next);
5429 uint32_t next_memory = evaluate_constant_u32(next_ops[1]);
5430 uint32_t next_semantics = evaluate_constant_u32(next_ops[2]);
5431 next_semantics = mask_relevant_memory_semantics(next_semantics);
5432
5433 // There is no "just execution barrier" in HLSL.
5434 			// If the next instruction carries no memory semantics, assume it synchronizes group-shared memory.
5435 if (next_semantics == 0)
5436 next_semantics = MemorySemanticsWorkgroupMemoryMask;
5437
5438 bool memory_scope_covered = false;
5439 if (next_memory == memory)
5440 memory_scope_covered = true;
5441 else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
5442 {
5443 // If we only care about workgroup memory, either Device or Workgroup scope is fine,
5444 // scope does not have to match.
5445 if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
5446 (memory == ScopeDevice || memory == ScopeWorkgroup))
5447 {
5448 memory_scope_covered = true;
5449 }
5450 }
5451 else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
5452 {
5453 // The control barrier has device scope, but the memory barrier just has workgroup scope.
5454 memory_scope_covered = true;
5455 }
5456
5457 // If we have the same memory scope, and all memory types are covered, we're good.
5458 if (memory_scope_covered && (semantics & next_semantics) == semantics)
5459 break;
5460 }
5461 }
5462
5463 // We are synchronizing some memory or syncing execution,
5464 // so we cannot forward any loads beyond the memory barrier.
5465 if (semantics || opcode == OpControlBarrier)
5466 {
5467 assert(current_emitting_block);
5468 flush_control_dependent_expressions(current_emitting_block->self);
5469 flush_all_active_variables();
5470 }
5471
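// Pick the narrowest HLSL intrinsic that still covers the requested semantics:
// workgroup-only maps to Group*, non-workgroup-only to Device*, anything mixed
// to All*.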
5472 if (opcode == OpControlBarrier)
5473 {
5474 		// We cannot emit a pure execution barrier; if there are no memory semantics, pick the cheapest option.
5475 if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
5476 statement("GroupMemoryBarrierWithGroupSync();");
5477 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5478 statement("DeviceMemoryBarrierWithGroupSync();");
5479 else
5480 statement("AllMemoryBarrierWithGroupSync();");
5481 }
5482 else
5483 {
5484 if (semantics == MemorySemanticsWorkgroupMemoryMask)
5485 statement("GroupMemoryBarrier();");
5486 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5487 statement("DeviceMemoryBarrier();");
5488 else
5489 statement("AllMemoryBarrier();");
5490 }
5491 break;
5492 }
5493
5494 case OpBitFieldInsert:
5495 {
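// spvBitfieldInsert() is a helper emitted in the shader preamble; requesting it
// mid-compile means this pass is incomplete, so force another compilation pass.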
5496 if (!requires_bitfield_insert)
5497 {
5498 requires_bitfield_insert = true;
5499 force_recompile();
5500 }
5501
5502 auto expr = join("spvBitfieldInsert(", to_expression(ops[2]), ", ", to_expression(ops[3]), ", ",
5503 to_expression(ops[4]), ", ", to_expression(ops[5]), ")");
5504
5505 bool forward =
5506 should_forward(ops[2]) && should_forward(ops[3]) && should_forward(ops[4]) && should_forward(ops[5]);
5507
5508 auto &restype = get<SPIRType>(ops[0]);
5509 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5510 emit_op(ops[0], ops[1], expr, forward);
5511 break;
5512 }
5513
5514 case OpBitFieldSExtract:
5515 case OpBitFieldUExtract:
5516 {
5517 if (!requires_bitfield_extract)
5518 {
5519 requires_bitfield_extract = true;
5520 force_recompile();
5521 }
5522
5523 if (opcode == OpBitFieldSExtract)
5524 HLSL_TFOP(spvBitfieldSExtract);
5525 else
5526 HLSL_TFOP(spvBitfieldUExtract);
5527 break;
5528 }
5529
5530 case OpBitCount:
5531 {
5532 auto basetype = expression_type(ops[2]).basetype;
5533 emit_unary_func_op_cast(ops[0], ops[1], ops[2], "countbits", basetype, basetype);
5534 break;
5535 }
5536
5537 case OpBitReverse:
5538 HLSL_UFOP(reversebits);
5539 break;
5540
5541 case OpArrayLength:
5542 {
5543 auto *var = maybe_get_backing_variable(ops[2]);
5544 if (!var)
5545 SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
5546
5547 auto &type = get<SPIRType>(var->basetype);
5548 if (!has_decoration(type.self, DecorationBlock) && !has_decoration(type.self, DecorationBufferBlock))
5549 SPIRV_CROSS_THROW("Array length expression must point to a block type.");
5550
5551 // This must be 32-bit uint, so we're good to go.
5552 emit_uninitialized_temporary_expression(ops[0], ops[1]);
5553 statement(to_non_uniform_aware_expression(ops[2]), ".GetDimensions(", to_expression(ops[1]), ");");
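// GetDimensions() reports the buffer size in bytes; convert it to an element
// count of the runtime array: count = (byte_size - member_offset) / array_stride.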
5554 uint32_t offset = type_struct_member_offset(type, ops[3]);
5555 uint32_t stride = type_struct_member_array_stride(type, ops[3]);
5556 statement(to_expression(ops[1]), " = (", to_expression(ops[1]), " - ", offset, ") / ", stride, ";");
5557 break;
5558 }
5559
5560 case OpIsHelperInvocationEXT:
5561 SPIRV_CROSS_THROW("helperInvocationEXT() is not supported in HLSL.");
5562
5563 case OpBeginInvocationInterlockEXT:
5564 case OpEndInvocationInterlockEXT:
5565 if (hlsl_options.shader_model < 51)
5566 SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
5567 break; // Nothing to do in the body
5568
5569 default:
5570 CompilerGLSL::emit_instruction(instruction);
5571 break;
5572 }
5573 }
5574
5575 void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
5576 {
5577 if (const auto *var = maybe_get_backing_variable(var_id))
5578 var_id = var->self;
5579
5580 auto &type = expression_type(var_id);
5581 bool uav = type.image.sampled == 2;
5582 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
5583 uav = false;
5584
5585 uint32_t bit = 0;
5586 switch (type.image.dim)
5587 {
5588 case Dim1D:
5589 bit = type.image.arrayed ? Query1DArray : Query1D;
5590 break;
5591
5592 case Dim2D:
5593 if (type.image.ms)
5594 bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
5595 else
5596 bit = type.image.arrayed ? Query2DArray : Query2D;
5597 break;
5598
5599 case Dim3D:
5600 bit = Query3D;
5601 break;
5602
5603 case DimCube:
5604 bit = type.image.arrayed ? QueryCubeArray : QueryCube;
5605 break;
5606
5607 case DimBuffer:
5608 bit = QueryBuffer;
5609 break;
5610
5611 default:
5612 SPIRV_CROSS_THROW("Unsupported query type.");
5613 }
5614
5615 switch (get<SPIRType>(type.image.type).basetype)
5616 {
5617 case SPIRType::Float:
5618 bit += QueryTypeFloat;
5619 break;
5620
5621 case SPIRType::Int:
5622 bit += QueryTypeInt;
5623 break;
5624
5625 case SPIRType::UInt:
5626 bit += QueryTypeUInt;
5627 break;
5628
5629 default:
5630 SPIRV_CROSS_THROW("Unsupported query type.");
5631 }
5632
5633 auto norm_state = image_format_to_normalized_state(type.image.format);
5634 auto &variant = uav ? required_texture_size_variants
5635 .uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
5636 required_texture_size_variants.srv;
5637
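// Each (dimension, arrayedness, multisampling, sampled type) combination gets
// one bit; setting a previously unseen bit forces a recompile so the matching
// spvTextureSize()/spvImageSize() overload can be emitted in the preamble.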
5638 uint64_t mask = 1ull << bit;
5639 if ((variant & mask) == 0)
5640 {
5641 force_recompile();
5642 variant |= mask;
5643 }
5644 }
5645
5646 void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
5647 {
5648 	root_constants_layout = std::move(layout);
5649 }
5650
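// A minimal usage sketch for these remap hooks (illustrative only; the
// variable names below are not from this file):
//
//   CompilerHLSL hlsl(std::move(spirv_words));
//   hlsl.add_vertex_attribute_remap({ 0, "POSITION" }); // location 0 -> POSITION
//   std::string hlsl_source = hlsl.compile();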
5651 void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
5652 {
5653 remap_vertex_attributes.push_back(vertex_attributes);
5654 }
5655
5656 VariableID CompilerHLSL::remap_num_workgroups_builtin()
5657 {
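// HLSL has no NumWorkgroups builtin, so synthesize a uint3 inside a cbuffer
// named SPIRV_Cross_NumWorkgroups; the application is expected to bind the
// dispatch dimensions to the returned variable at runtime.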
5658 update_active_builtins();
5659
5660 if (!active_input_builtins.get(BuiltInNumWorkgroups))
5661 return 0;
5662
5663 // Create a new, fake UBO.
5664 uint32_t offset = ir.increase_bound_by(4);
5665
5666 uint32_t uint_type_id = offset;
5667 uint32_t block_type_id = offset + 1;
5668 uint32_t block_pointer_type_id = offset + 2;
5669 uint32_t variable_id = offset + 3;
5670
5671 SPIRType uint_type;
5672 uint_type.basetype = SPIRType::UInt;
5673 uint_type.width = 32;
5674 uint_type.vecsize = 3;
5675 uint_type.columns = 1;
5676 set<SPIRType>(uint_type_id, uint_type);
5677
5678 SPIRType block_type;
5679 block_type.basetype = SPIRType::Struct;
5680 block_type.member_types.push_back(uint_type_id);
5681 set<SPIRType>(block_type_id, block_type);
5682 set_decoration(block_type_id, DecorationBlock);
5683 set_member_name(block_type_id, 0, "count");
5684 set_member_decoration(block_type_id, 0, DecorationOffset, 0);
5685
5686 SPIRType block_pointer_type = block_type;
5687 block_pointer_type.pointer = true;
5688 block_pointer_type.storage = StorageClassUniform;
5689 block_pointer_type.parent_type = block_type_id;
5690 auto &ptr_type = set<SPIRType>(block_pointer_type_id, block_pointer_type);
5691
5692 // Preserve self.
5693 ptr_type.self = block_type_id;
5694
5695 set<SPIRVariable>(variable_id, block_pointer_type_id, StorageClassUniform);
5696 ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";
5697
5698 num_workgroups_builtin = variable_id;
5699 return variable_id;
5700 }
5701
5702 void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
5703 {
5704 resource_binding_flags = flags;
5705 }
5706
5707 void CompilerHLSL::validate_shader_model()
5708 {
5709 // Check for nonuniform qualifier.
5710 // Instead of looping over all decorations to find this, just look at capabilities.
5711 for (auto &cap : ir.declared_capabilities)
5712 {
5713 switch (cap)
5714 {
5715 case CapabilityShaderNonUniformEXT:
5716 case CapabilityRuntimeDescriptorArrayEXT:
5717 if (hlsl_options.shader_model < 51)
5718 SPIRV_CROSS_THROW(
5719 "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
5720 break;
5721
5722 case CapabilityVariablePointers:
5723 case CapabilityVariablePointersStorageBuffer:
5724 SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");
5725
5726 default:
5727 break;
5728 }
5729 }
5730
5731 if (ir.addressing_model != AddressingModelLogical)
5732 SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
5733
5734 if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
5735 SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
5736 }
5737
5738 string CompilerHLSL::compile()
5739 {
5740 ir.fixup_reserved_names();
5741
5742 // Do not deal with ES-isms like precision, older extensions and such.
5743 options.es = false;
5744 options.version = 450;
5745 options.vulkan_semantics = true;
5746 backend.float_literal_suffix = true;
5747 backend.double_literal_suffix = false;
5748 backend.long_long_literal_suffix = true;
5749 backend.uint32_t_literal_suffix = true;
5750 backend.int16_t_literal_suffix = "";
5751 backend.uint16_t_literal_suffix = "u";
5752 backend.basic_int_type = "int";
5753 backend.basic_uint_type = "uint";
5754 backend.demote_literal = "discard";
5755 backend.boolean_mix_function = "";
5756 backend.swizzle_is_function = false;
5757 backend.shared_is_implied = true;
5758 backend.unsized_array_supported = true;
5759 backend.explicit_struct_type = false;
5760 backend.use_initializer_list = true;
5761 backend.use_constructor_splatting = false;
5762 backend.can_swizzle_scalar = true;
5763 backend.can_declare_struct_inline = false;
5764 backend.can_declare_arrays_inline = false;
5765 backend.can_return_array = false;
5766 backend.nonuniform_qualifier = "NonUniformResourceIndex";
5767 backend.support_case_fallthrough = false;
5768
5769 // SM 4.1 does not support precise for some reason.
5770 backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;
5771
5772 fixup_type_alias();
5773 reorder_type_alias();
5774 build_function_control_flow_graphs_and_analyze();
5775 validate_shader_model();
5776 update_active_builtins();
5777 analyze_image_and_sampler_usage();
5778 analyze_interlocked_resource_usage();
5779
5780 // Subpass input needs SV_Position.
5781 if (need_subpass_input)
5782 active_input_builtins.set(BuiltInFragCoord);
5783
5784 uint32_t pass_count = 0;
5785 do
5786 {
5787 if (pass_count >= 3)
5788 SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
5789
5790 reset();
5791
5792 // Move constructor for this type is broken on GCC 4.9 ...
5793 buffer.reset();
5794
5795 emit_header();
5796 emit_resources();
5797
5798 emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
5799 emit_hlsl_entry_point();
5800
5801 pass_count++;
5802 } while (is_forcing_recompilation());
5803
5804 // Entry point in HLSL is always main() for the time being.
5805 get_entry_point().name = "main";
5806
5807 return buffer.str();
5808 }
5809
5810 void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
5811 {
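// SPIR-V selection/loop control hints map one-to-one onto HLSL attributes,
// emitted on the line preceding the branch or loop.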
5812 switch (block.hint)
5813 {
5814 case SPIRBlock::HintFlatten:
5815 statement("[flatten]");
5816 break;
5817 case SPIRBlock::HintDontFlatten:
5818 statement("[branch]");
5819 break;
5820 case SPIRBlock::HintUnroll:
5821 statement("[unroll]");
5822 break;
5823 case SPIRBlock::HintDontUnroll:
5824 statement("[loop]");
5825 break;
5826 default:
5827 break;
5828 }
5829 }
5830
5831 string CompilerHLSL::get_unique_identifier()
5832 {
5833 return join("_", unique_identifier_count++, "ident");
5834 }
5835
5836 void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
5837 {
5838 StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
5839 resource_bindings[tuple] = { binding, false };
5840 }
5841
5842 bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
5843 {
5844 StageSetBinding tuple = { model, desc_set, binding };
5845 auto itr = resource_bindings.find(tuple);
5846 return itr != end(resource_bindings) && itr->second.second;
5847 }
5848
5849 CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
5850 {
5851 auto &rslt_type = get<SPIRType>(result_type);
5852 auto &expr_type = expression_type(op0);
5853
5854 if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
5855 expr_type.vecsize == 2)
5856 return BitcastType::TypePackUint2x32;
5857 else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
5858 expr_type.basetype == SPIRType::BaseType::UInt64)
5859 return BitcastType::TypeUnpackUint64;
5860
5861 return BitcastType::TypeNormal;
5862 }
5863
5864 bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
5865 {
5866 if (hlsl_options.force_storage_buffer_as_uav)
5867 {
5868 return true;
5869 }
5870
5871 const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
5872 const uint32_t binding = get_decoration(id, spv::DecorationBinding);
5873
5874 return (force_uav_buffer_bindings.find({ desc_set, binding }) != force_uav_buffer_bindings.end());
5875 }
5876
5877 void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
5878 {
5879 SetBindingPair pair = { desc_set, binding };
5880 force_uav_buffer_bindings.insert(pair);
5881 }
5882
5883 bool CompilerHLSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
5884 {
5885 return (builtin == BuiltInSampleMask);
5886 }
5887