• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2021 Robert Konrad
3  * SPDX-License-Identifier: Apache-2.0 OR MIT
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 /*
20  * At your option, you may choose to accept this material under either:
21  *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
22  *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
23  */
24 
25 #include "spirv_hlsl.hpp"
26 #include "GLSL.std.450.h"
27 #include <algorithm>
28 #include <assert.h>
29 
30 using namespace spv;
31 using namespace SPIRV_CROSS_NAMESPACE;
32 using namespace std;
33 
34 enum class ImageFormatNormalizedState
35 {
36 	None = 0,
37 	Unorm = 1,
38 	Snorm = 2
39 };
40 
image_format_to_normalized_state(ImageFormat fmt)41 static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
42 {
43 	switch (fmt)
44 	{
45 	case ImageFormatR8:
46 	case ImageFormatR16:
47 	case ImageFormatRg8:
48 	case ImageFormatRg16:
49 	case ImageFormatRgba8:
50 	case ImageFormatRgba16:
51 	case ImageFormatRgb10A2:
52 		return ImageFormatNormalizedState::Unorm;
53 
54 	case ImageFormatR8Snorm:
55 	case ImageFormatR16Snorm:
56 	case ImageFormatRg8Snorm:
57 	case ImageFormatRg16Snorm:
58 	case ImageFormatRgba8Snorm:
59 	case ImageFormatRgba16Snorm:
60 		return ImageFormatNormalizedState::Snorm;
61 
62 	default:
63 		break;
64 	}
65 
66 	return ImageFormatNormalizedState::None;
67 }
68 
image_format_to_components(ImageFormat fmt)69 static unsigned image_format_to_components(ImageFormat fmt)
70 {
71 	switch (fmt)
72 	{
73 	case ImageFormatR8:
74 	case ImageFormatR16:
75 	case ImageFormatR8Snorm:
76 	case ImageFormatR16Snorm:
77 	case ImageFormatR16f:
78 	case ImageFormatR32f:
79 	case ImageFormatR8i:
80 	case ImageFormatR16i:
81 	case ImageFormatR32i:
82 	case ImageFormatR8ui:
83 	case ImageFormatR16ui:
84 	case ImageFormatR32ui:
85 		return 1;
86 
87 	case ImageFormatRg8:
88 	case ImageFormatRg16:
89 	case ImageFormatRg8Snorm:
90 	case ImageFormatRg16Snorm:
91 	case ImageFormatRg16f:
92 	case ImageFormatRg32f:
93 	case ImageFormatRg8i:
94 	case ImageFormatRg16i:
95 	case ImageFormatRg32i:
96 	case ImageFormatRg8ui:
97 	case ImageFormatRg16ui:
98 	case ImageFormatRg32ui:
99 		return 2;
100 
101 	case ImageFormatR11fG11fB10f:
102 		return 3;
103 
104 	case ImageFormatRgba8:
105 	case ImageFormatRgba16:
106 	case ImageFormatRgb10A2:
107 	case ImageFormatRgba8Snorm:
108 	case ImageFormatRgba16Snorm:
109 	case ImageFormatRgba16f:
110 	case ImageFormatRgba32f:
111 	case ImageFormatRgba8i:
112 	case ImageFormatRgba16i:
113 	case ImageFormatRgba32i:
114 	case ImageFormatRgba8ui:
115 	case ImageFormatRgba16ui:
116 	case ImageFormatRgba32ui:
117 	case ImageFormatRgb10a2ui:
118 		return 4;
119 
120 	case ImageFormatUnknown:
121 		return 4; // Assume 4.
122 
123 	default:
124 		SPIRV_CROSS_THROW("Unrecognized typed image format.");
125 	}
126 }
127 
image_format_to_type(ImageFormat fmt,SPIRType::BaseType basetype)128 static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
129 {
130 	switch (fmt)
131 	{
132 	case ImageFormatR8:
133 	case ImageFormatR16:
134 		if (basetype != SPIRType::Float)
135 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
136 		return "unorm float";
137 	case ImageFormatRg8:
138 	case ImageFormatRg16:
139 		if (basetype != SPIRType::Float)
140 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
141 		return "unorm float2";
142 	case ImageFormatRgba8:
143 	case ImageFormatRgba16:
144 		if (basetype != SPIRType::Float)
145 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
146 		return "unorm float4";
147 	case ImageFormatRgb10A2:
148 		if (basetype != SPIRType::Float)
149 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
150 		return "unorm float4";
151 
152 	case ImageFormatR8Snorm:
153 	case ImageFormatR16Snorm:
154 		if (basetype != SPIRType::Float)
155 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
156 		return "snorm float";
157 	case ImageFormatRg8Snorm:
158 	case ImageFormatRg16Snorm:
159 		if (basetype != SPIRType::Float)
160 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
161 		return "snorm float2";
162 	case ImageFormatRgba8Snorm:
163 	case ImageFormatRgba16Snorm:
164 		if (basetype != SPIRType::Float)
165 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
166 		return "snorm float4";
167 
168 	case ImageFormatR16f:
169 	case ImageFormatR32f:
170 		if (basetype != SPIRType::Float)
171 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
172 		return "float";
173 	case ImageFormatRg16f:
174 	case ImageFormatRg32f:
175 		if (basetype != SPIRType::Float)
176 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
177 		return "float2";
178 	case ImageFormatRgba16f:
179 	case ImageFormatRgba32f:
180 		if (basetype != SPIRType::Float)
181 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
182 		return "float4";
183 
184 	case ImageFormatR11fG11fB10f:
185 		if (basetype != SPIRType::Float)
186 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
187 		return "float3";
188 
189 	case ImageFormatR8i:
190 	case ImageFormatR16i:
191 	case ImageFormatR32i:
192 		if (basetype != SPIRType::Int)
193 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
194 		return "int";
195 	case ImageFormatRg8i:
196 	case ImageFormatRg16i:
197 	case ImageFormatRg32i:
198 		if (basetype != SPIRType::Int)
199 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
200 		return "int2";
201 	case ImageFormatRgba8i:
202 	case ImageFormatRgba16i:
203 	case ImageFormatRgba32i:
204 		if (basetype != SPIRType::Int)
205 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
206 		return "int4";
207 
208 	case ImageFormatR8ui:
209 	case ImageFormatR16ui:
210 	case ImageFormatR32ui:
211 		if (basetype != SPIRType::UInt)
212 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
213 		return "uint";
214 	case ImageFormatRg8ui:
215 	case ImageFormatRg16ui:
216 	case ImageFormatRg32ui:
217 		if (basetype != SPIRType::UInt)
218 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
219 		return "uint2";
220 	case ImageFormatRgba8ui:
221 	case ImageFormatRgba16ui:
222 	case ImageFormatRgba32ui:
223 		if (basetype != SPIRType::UInt)
224 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
225 		return "uint4";
226 	case ImageFormatRgb10a2ui:
227 		if (basetype != SPIRType::UInt)
228 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
229 		return "uint4";
230 
231 	case ImageFormatUnknown:
232 		switch (basetype)
233 		{
234 		case SPIRType::Float:
235 			return "float4";
236 		case SPIRType::Int:
237 			return "int4";
238 		case SPIRType::UInt:
239 			return "uint4";
240 		default:
241 			SPIRV_CROSS_THROW("Unsupported base type for image.");
242 		}
243 
244 	default:
245 		SPIRV_CROSS_THROW("Unrecognized typed image format.");
246 	}
247 }
248 
image_type_hlsl_modern(const SPIRType & type,uint32_t id)249 string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
250 {
251 	auto &imagetype = get<SPIRType>(type.image.type);
252 	const char *dim = nullptr;
253 	bool typed_load = false;
254 	uint32_t components = 4;
255 
256 	bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, DecorationNonWritable);
257 
258 	switch (type.image.dim)
259 	{
260 	case Dim1D:
261 		typed_load = type.image.sampled == 2;
262 		dim = "1D";
263 		break;
264 	case Dim2D:
265 		typed_load = type.image.sampled == 2;
266 		dim = "2D";
267 		break;
268 	case Dim3D:
269 		typed_load = type.image.sampled == 2;
270 		dim = "3D";
271 		break;
272 	case DimCube:
273 		if (type.image.sampled == 2)
274 			SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
275 		dim = "Cube";
276 		break;
277 	case DimRect:
278 		SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
279 	case DimBuffer:
280 		if (type.image.sampled == 1)
281 			return join("Buffer<", type_to_glsl(imagetype), components, ">");
282 		else if (type.image.sampled == 2)
283 		{
284 			if (interlocked_resources.count(id))
285 				return join("RasterizerOrderedBuffer<", image_format_to_type(type.image.format, imagetype.basetype),
286 				            ">");
287 
288 			typed_load = !force_image_srv && type.image.sampled == 2;
289 
290 			const char *rw = force_image_srv ? "" : "RW";
291 			return join(rw, "Buffer<",
292 			            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
293 			                         join(type_to_glsl(imagetype), components),
294 			            ">");
295 		}
296 		else
297 			SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce in runtime.");
298 	case DimSubpassData:
299 		dim = "2D";
300 		typed_load = false;
301 		break;
302 	default:
303 		SPIRV_CROSS_THROW("Invalid dimension.");
304 	}
305 	const char *arrayed = type.image.arrayed ? "Array" : "";
306 	const char *ms = type.image.ms ? "MS" : "";
307 	const char *rw = typed_load && !force_image_srv ? "RW" : "";
308 
309 	if (force_image_srv)
310 		typed_load = false;
311 
312 	if (typed_load && interlocked_resources.count(id))
313 		rw = "RasterizerOrdered";
314 
315 	return join(rw, "Texture", dim, ms, arrayed, "<",
316 	            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
317 	                         join(type_to_glsl(imagetype), components),
318 	            ">");
319 }
320 
image_type_hlsl_legacy(const SPIRType & type,uint32_t)321 string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
322 {
323 	auto &imagetype = get<SPIRType>(type.image.type);
324 	string res;
325 
326 	switch (imagetype.basetype)
327 	{
328 	case SPIRType::Int:
329 		res = "i";
330 		break;
331 	case SPIRType::UInt:
332 		res = "u";
333 		break;
334 	default:
335 		break;
336 	}
337 
338 	if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
339 		return res + "subpassInput" + (type.image.ms ? "MS" : "");
340 
341 	// If we're emulating subpassInput with samplers, force sampler2D
342 	// so we don't have to specify format.
343 	if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
344 	{
345 		// Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
346 		if (type.image.dim == DimBuffer && type.image.sampled == 1)
347 			res += "sampler";
348 		else
349 			res += type.image.sampled == 2 ? "image" : "texture";
350 	}
351 	else
352 		res += "sampler";
353 
354 	switch (type.image.dim)
355 	{
356 	case Dim1D:
357 		res += "1D";
358 		break;
359 	case Dim2D:
360 		res += "2D";
361 		break;
362 	case Dim3D:
363 		res += "3D";
364 		break;
365 	case DimCube:
366 		res += "CUBE";
367 		break;
368 
369 	case DimBuffer:
370 		res += "Buffer";
371 		break;
372 
373 	case DimSubpassData:
374 		res += "2D";
375 		break;
376 	default:
377 		SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
378 	}
379 
380 	if (type.image.ms)
381 		res += "MS";
382 	if (type.image.arrayed)
383 		res += "Array";
384 
385 	return res;
386 }
387 
image_type_hlsl(const SPIRType & type,uint32_t id)388 string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
389 {
390 	if (hlsl_options.shader_model <= 30)
391 		return image_type_hlsl_legacy(type, id);
392 	else
393 		return image_type_hlsl_modern(type, id);
394 }
395 
396 // The optional id parameter indicates the object whose type we are trying
397 // to find the description for. It is optional. Most type descriptions do not
398 // depend on a specific object's use of that type.
type_to_glsl(const SPIRType & type,uint32_t id)399 string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
400 {
401 	// Ignore the pointer type since GLSL doesn't have pointers.
402 
403 	switch (type.basetype)
404 	{
405 	case SPIRType::Struct:
406 		// Need OpName lookup here to get a "sensible" name for a struct.
407 		if (backend.explicit_struct_type)
408 			return join("struct ", to_name(type.self));
409 		else
410 			return to_name(type.self);
411 
412 	case SPIRType::Image:
413 	case SPIRType::SampledImage:
414 		return image_type_hlsl(type, id);
415 
416 	case SPIRType::Sampler:
417 		return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState";
418 
419 	case SPIRType::Void:
420 		return "void";
421 
422 	default:
423 		break;
424 	}
425 
426 	if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
427 	{
428 		switch (type.basetype)
429 		{
430 		case SPIRType::Boolean:
431 			return "bool";
432 		case SPIRType::Int:
433 			return backend.basic_int_type;
434 		case SPIRType::UInt:
435 			return backend.basic_uint_type;
436 		case SPIRType::AtomicCounter:
437 			return "atomic_uint";
438 		case SPIRType::Half:
439 			if (hlsl_options.enable_16bit_types)
440 				return "half";
441 			else
442 				return "min16float";
443 		case SPIRType::Short:
444 			if (hlsl_options.enable_16bit_types)
445 				return "int16_t";
446 			else
447 				return "min16int";
448 		case SPIRType::UShort:
449 			if (hlsl_options.enable_16bit_types)
450 				return "uint16_t";
451 			else
452 				return "min16uint";
453 		case SPIRType::Float:
454 			return "float";
455 		case SPIRType::Double:
456 			return "double";
457 		case SPIRType::Int64:
458 			if (hlsl_options.shader_model < 60)
459 				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
460 			return "int64_t";
461 		case SPIRType::UInt64:
462 			if (hlsl_options.shader_model < 60)
463 				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
464 			return "uint64_t";
465 		default:
466 			return "???";
467 		}
468 	}
469 	else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
470 	{
471 		switch (type.basetype)
472 		{
473 		case SPIRType::Boolean:
474 			return join("bool", type.vecsize);
475 		case SPIRType::Int:
476 			return join("int", type.vecsize);
477 		case SPIRType::UInt:
478 			return join("uint", type.vecsize);
479 		case SPIRType::Half:
480 			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
481 		case SPIRType::Short:
482 			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
483 		case SPIRType::UShort:
484 			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
485 		case SPIRType::Float:
486 			return join("float", type.vecsize);
487 		case SPIRType::Double:
488 			return join("double", type.vecsize);
489 		case SPIRType::Int64:
490 			return join("i64vec", type.vecsize);
491 		case SPIRType::UInt64:
492 			return join("u64vec", type.vecsize);
493 		default:
494 			return "???";
495 		}
496 	}
497 	else
498 	{
499 		switch (type.basetype)
500 		{
501 		case SPIRType::Boolean:
502 			return join("bool", type.columns, "x", type.vecsize);
503 		case SPIRType::Int:
504 			return join("int", type.columns, "x", type.vecsize);
505 		case SPIRType::UInt:
506 			return join("uint", type.columns, "x", type.vecsize);
507 		case SPIRType::Half:
508 			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
509 		case SPIRType::Short:
510 			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
511 		case SPIRType::UShort:
512 			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
513 		case SPIRType::Float:
514 			return join("float", type.columns, "x", type.vecsize);
515 		case SPIRType::Double:
516 			return join("double", type.columns, "x", type.vecsize);
517 		// Matrix types not supported for int64/uint64.
518 		default:
519 			return "???";
520 		}
521 	}
522 }
523 
emit_header()524 void CompilerHLSL::emit_header()
525 {
526 	for (auto &header : header_lines)
527 		statement(header);
528 
529 	if (header_lines.size() > 0)
530 	{
531 		statement("");
532 	}
533 }
534 
emit_interface_block_globally(const SPIRVariable & var)535 void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
536 {
537 	add_resource_name(var.self);
538 
539 	// The global copies of I/O variables should not contain interpolation qualifiers.
540 	// These are emitted inside the interface structs.
541 	auto &flags = ir.meta[var.self].decoration.decoration_flags;
542 	auto old_flags = flags;
543 	flags.reset();
544 	statement("static ", variable_decl(var), ";");
545 	flags = old_flags;
546 }
547 
to_storage_qualifiers_glsl(const SPIRVariable & var)548 const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
549 {
550 	// Input and output variables are handled specially in HLSL backend.
551 	// The variables are declared as global, private variables, and do not need any qualifiers.
552 	if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
553 	    var.storage == StorageClassPushConstant)
554 	{
555 		return "uniform ";
556 	}
557 
558 	return "";
559 }
560 
emit_builtin_outputs_in_struct()561 void CompilerHLSL::emit_builtin_outputs_in_struct()
562 {
563 	auto &execution = get_entry_point();
564 
565 	bool legacy = hlsl_options.shader_model <= 30;
566 	active_output_builtins.for_each_bit([&](uint32_t i) {
567 		const char *type = nullptr;
568 		const char *semantic = nullptr;
569 		auto builtin = static_cast<BuiltIn>(i);
570 		switch (builtin)
571 		{
572 		case BuiltInPosition:
573 			type = is_position_invariant() && backend.support_precise_qualifier ? "precise float4" : "float4";
574 			semantic = legacy ? "POSITION" : "SV_Position";
575 			break;
576 
577 		case BuiltInSampleMask:
578 			if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
579 				SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
580 			type = "uint";
581 			semantic = "SV_Coverage";
582 			break;
583 
584 		case BuiltInFragDepth:
585 			type = "float";
586 			if (legacy)
587 			{
588 				semantic = "DEPTH";
589 			}
590 			else
591 			{
592 				if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthGreater))
593 					semantic = "SV_DepthGreaterEqual";
594 				else if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthLess))
595 					semantic = "SV_DepthLessEqual";
596 				else
597 					semantic = "SV_Depth";
598 			}
599 			break;
600 
601 		case BuiltInClipDistance:
602 			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
603 			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
604 			{
605 				uint32_t to_declare = clip_distance_count - clip;
606 				if (to_declare > 4)
607 					to_declare = 4;
608 
609 				uint32_t semantic_index = clip / 4;
610 
611 				static const char *types[] = { "float", "float2", "float3", "float4" };
612 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
613 				          " : SV_ClipDistance", semantic_index, ";");
614 			}
615 			break;
616 
617 		case BuiltInCullDistance:
618 			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
619 			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
620 			{
621 				uint32_t to_declare = cull_distance_count - cull;
622 				if (to_declare > 4)
623 					to_declare = 4;
624 
625 				uint32_t semantic_index = cull / 4;
626 
627 				static const char *types[] = { "float", "float2", "float3", "float4" };
628 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
629 				          " : SV_CullDistance", semantic_index, ";");
630 			}
631 			break;
632 
633 		case BuiltInPointSize:
634 			// If point_size_compat is enabled, just ignore PointSize.
635 			// PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
636 			// even if it means working around the missing feature.
637 			if (hlsl_options.point_size_compat)
638 				break;
639 			else
640 				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
641 
642 		default:
643 			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
644 		}
645 
646 		if (type && semantic)
647 			statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
648 	});
649 }
650 
emit_builtin_inputs_in_struct()651 void CompilerHLSL::emit_builtin_inputs_in_struct()
652 {
653 	bool legacy = hlsl_options.shader_model <= 30;
654 	active_input_builtins.for_each_bit([&](uint32_t i) {
655 		const char *type = nullptr;
656 		const char *semantic = nullptr;
657 		auto builtin = static_cast<BuiltIn>(i);
658 		switch (builtin)
659 		{
660 		case BuiltInFragCoord:
661 			type = "float4";
662 			semantic = legacy ? "VPOS" : "SV_Position";
663 			break;
664 
665 		case BuiltInVertexId:
666 		case BuiltInVertexIndex:
667 			if (legacy)
668 				SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
669 			type = "uint";
670 			semantic = "SV_VertexID";
671 			break;
672 
673 		case BuiltInInstanceId:
674 		case BuiltInInstanceIndex:
675 			if (legacy)
676 				SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
677 			type = "uint";
678 			semantic = "SV_InstanceID";
679 			break;
680 
681 		case BuiltInSampleId:
682 			if (legacy)
683 				SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
684 			type = "uint";
685 			semantic = "SV_SampleIndex";
686 			break;
687 
688 		case BuiltInSampleMask:
689 			if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
690 				SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
691 			type = "uint";
692 			semantic = "SV_Coverage";
693 			break;
694 
695 		case BuiltInGlobalInvocationId:
696 			type = "uint3";
697 			semantic = "SV_DispatchThreadID";
698 			break;
699 
700 		case BuiltInLocalInvocationId:
701 			type = "uint3";
702 			semantic = "SV_GroupThreadID";
703 			break;
704 
705 		case BuiltInLocalInvocationIndex:
706 			type = "uint";
707 			semantic = "SV_GroupIndex";
708 			break;
709 
710 		case BuiltInWorkgroupId:
711 			type = "uint3";
712 			semantic = "SV_GroupID";
713 			break;
714 
715 		case BuiltInFrontFacing:
716 			type = "bool";
717 			semantic = "SV_IsFrontFace";
718 			break;
719 
720 		case BuiltInNumWorkgroups:
721 		case BuiltInSubgroupSize:
722 		case BuiltInSubgroupLocalInvocationId:
723 		case BuiltInSubgroupEqMask:
724 		case BuiltInSubgroupLtMask:
725 		case BuiltInSubgroupLeMask:
726 		case BuiltInSubgroupGtMask:
727 		case BuiltInSubgroupGeMask:
728 			// Handled specially.
729 			break;
730 
731 		case BuiltInClipDistance:
732 			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
733 			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
734 			{
735 				uint32_t to_declare = clip_distance_count - clip;
736 				if (to_declare > 4)
737 					to_declare = 4;
738 
739 				uint32_t semantic_index = clip / 4;
740 
741 				static const char *types[] = { "float", "float2", "float3", "float4" };
742 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
743 				          " : SV_ClipDistance", semantic_index, ";");
744 			}
745 			break;
746 
747 		case BuiltInCullDistance:
748 			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
749 			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
750 			{
751 				uint32_t to_declare = cull_distance_count - cull;
752 				if (to_declare > 4)
753 					to_declare = 4;
754 
755 				uint32_t semantic_index = cull / 4;
756 
757 				static const char *types[] = { "float", "float2", "float3", "float4" };
758 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
759 				          " : SV_CullDistance", semantic_index, ";");
760 			}
761 			break;
762 
763 		case BuiltInPointCoord:
764 			// PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
765 			if (hlsl_options.point_coord_compat)
766 				break;
767 			else
768 				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
769 
770 		default:
771 			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
772 		}
773 
774 		if (type && semantic)
775 			statement(type, " ", builtin_to_glsl(builtin, StorageClassInput), " : ", semantic, ";");
776 	});
777 }
778 
type_to_consumed_locations(const SPIRType & type) const779 uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
780 {
781 	// TODO: Need to verify correctness.
782 	uint32_t elements = 0;
783 
784 	if (type.basetype == SPIRType::Struct)
785 	{
786 		for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
787 			elements += type_to_consumed_locations(get<SPIRType>(type.member_types[i]));
788 	}
789 	else
790 	{
791 		uint32_t array_multiplier = 1;
792 		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
793 		{
794 			if (type.array_size_literal[i])
795 				array_multiplier *= type.array[i];
796 			else
797 				array_multiplier *= evaluate_constant_u32(type.array[i]);
798 		}
799 		elements += array_multiplier * type.columns;
800 	}
801 	return elements;
802 }
803 
to_interpolation_qualifiers(const Bitset & flags)804 string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
805 {
806 	string res;
807 	//if (flags & (1ull << DecorationSmooth))
808 	//    res += "linear ";
809 	if (flags.get(DecorationFlat))
810 		res += "nointerpolation ";
811 	if (flags.get(DecorationNoPerspective))
812 		res += "noperspective ";
813 	if (flags.get(DecorationCentroid))
814 		res += "centroid ";
815 	if (flags.get(DecorationPatch))
816 		res += "patch "; // Seems to be different in actual HLSL.
817 	if (flags.get(DecorationSample))
818 		res += "sample ";
819 	if (flags.get(DecorationInvariant) && backend.support_precise_qualifier)
820 		res += "precise "; // Not supported?
821 
822 	return res;
823 }
824 
to_semantic(uint32_t location,ExecutionModel em,StorageClass sc)825 std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
826 {
827 	if (em == ExecutionModelVertex && sc == StorageClassInput)
828 	{
829 		// We have a vertex attribute - we should look at remapping it if the user provided
830 		// vertex attribute hints.
831 		for (auto &attribute : remap_vertex_attributes)
832 			if (attribute.location == location)
833 				return attribute.semantic;
834 	}
835 
836 	// Not a vertex attribute, or no remap_vertex_attributes entry.
837 	return join("TEXCOORD", location);
838 }
839 
to_initializer_expression(const SPIRVariable & var)840 std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
841 {
842 	// We cannot emit static const initializer for block constants for practical reasons,
843 	// so just inline the initializer.
844 	// FIXME: There is a theoretical problem here if someone tries to composite extract
845 	// into this initializer since we don't declare it properly, but that is somewhat non-sensical.
846 	auto &type = get<SPIRType>(var.basetype);
847 	bool is_block = has_decoration(type.self, DecorationBlock);
848 	auto *c = maybe_get<SPIRConstant>(var.initializer);
849 	if (is_block && c)
850 		return constant_expression(*c);
851 	else
852 		return CompilerGLSL::to_initializer_expression(var);
853 }
854 
emit_interface_block_member_in_struct(const SPIRVariable & var,uint32_t member_index,uint32_t location,std::unordered_set<uint32_t> & active_locations)855 void CompilerHLSL::emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index,
856                                                          uint32_t location,
857                                                          std::unordered_set<uint32_t> &active_locations)
858 {
859 	auto &execution = get_entry_point();
860 	auto type = get<SPIRType>(var.basetype);
861 	auto semantic = to_semantic(location, execution.model, var.storage);
862 	auto mbr_name = join(to_name(type.self), "_", to_member_name(type, member_index));
863 	auto &mbr_type = get<SPIRType>(type.member_types[member_index]);
864 
865 	statement(to_interpolation_qualifiers(get_member_decoration_bitset(type.self, member_index)),
866 	          type_to_glsl(mbr_type),
867 	          " ", mbr_name, type_to_array_glsl(mbr_type),
868 	          " : ", semantic, ";");
869 
870 	// Structs and arrays should consume more locations.
871 	uint32_t consumed_locations = type_to_consumed_locations(mbr_type);
872 	for (uint32_t i = 0; i < consumed_locations; i++)
873 		active_locations.insert(location + i);
874 }
875 
emit_interface_block_in_struct(const SPIRVariable & var,unordered_set<uint32_t> & active_locations)876 void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
877 {
878 	auto &execution = get_entry_point();
879 	auto type = get<SPIRType>(var.basetype);
880 
881 	string binding;
882 	bool use_location_number = true;
883 	bool legacy = hlsl_options.shader_model <= 30;
884 	if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
885 	{
886 		// Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
887 		uint32_t index = get_decoration(var.self, DecorationIndex);
888 		uint32_t location = get_decoration(var.self, DecorationLocation);
889 
890 		if (index != 0 && location != 0)
891 			SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");
892 
893 		binding = join(legacy ? "COLOR" : "SV_Target", location + index);
894 		use_location_number = false;
895 		if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
896 			type.vecsize = 4;
897 	}
898 
899 	const auto get_vacant_location = [&]() -> uint32_t {
900 		for (uint32_t i = 0; i < 64; i++)
901 			if (!active_locations.count(i))
902 				return i;
903 		SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
904 	};
905 
906 	bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
907 
908 	auto name = to_name(var.self);
909 	if (use_location_number)
910 	{
911 		uint32_t location_number;
912 
913 		// If an explicit location exists, use it with TEXCOORD[N] semantic.
914 		// Otherwise, pick a vacant location.
915 		if (has_decoration(var.self, DecorationLocation))
916 			location_number = get_decoration(var.self, DecorationLocation);
917 		else
918 			location_number = get_vacant_location();
919 
920 		// Allow semantic remap if specified.
921 		auto semantic = to_semantic(location_number, execution.model, var.storage);
922 
923 		if (need_matrix_unroll && type.columns > 1)
924 		{
925 			if (!type.array.empty())
926 				SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");
927 
928 			// Unroll matrices.
929 			for (uint32_t i = 0; i < type.columns; i++)
930 			{
931 				SPIRType newtype = type;
932 				newtype.columns = 1;
933 
934 				string effective_semantic;
935 				if (hlsl_options.flatten_matrix_vertex_input_semantics)
936 					effective_semantic = to_semantic(location_number, execution.model, var.storage);
937 				else
938 					effective_semantic = join(semantic, "_", i);
939 
940 				statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)),
941 				          variable_decl(newtype, join(name, "_", i)), " : ", effective_semantic, ";");
942 				active_locations.insert(location_number++);
943 			}
944 		}
945 		else
946 		{
947 			statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(type, name), " : ",
948 			          semantic, ";");
949 
950 			// Structs and arrays should consume more locations.
951 			uint32_t consumed_locations = type_to_consumed_locations(type);
952 			for (uint32_t i = 0; i < consumed_locations; i++)
953 				active_locations.insert(location_number + i);
954 		}
955 	}
956 	else
957 		statement(variable_decl(type, name), " : ", binding, ";");
958 }
959 
builtin_to_glsl(spv::BuiltIn builtin,spv::StorageClass storage)960 std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
961 {
962 	switch (builtin)
963 	{
964 	case BuiltInVertexId:
965 		return "gl_VertexID";
966 	case BuiltInInstanceId:
967 		return "gl_InstanceID";
968 	case BuiltInNumWorkgroups:
969 	{
970 		if (!num_workgroups_builtin)
971 			SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
972 			                  "Cannot emit code for this builtin.");
973 
974 		auto &var = get<SPIRVariable>(num_workgroups_builtin);
975 		auto &type = get<SPIRType>(var.basetype);
976 		auto ret = join(to_name(num_workgroups_builtin), "_", get_member_name(type.self, 0));
977 		ParsedIR::sanitize_underscores(ret);
978 		return ret;
979 	}
980 	case BuiltInPointCoord:
981 		// Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
982 		return "float2(0.5f, 0.5f)";
983 	case BuiltInSubgroupLocalInvocationId:
984 		return "WaveGetLaneIndex()";
985 	case BuiltInSubgroupSize:
986 		return "WaveGetLaneCount()";
987 
988 	default:
989 		return CompilerGLSL::builtin_to_glsl(builtin, storage);
990 	}
991 }
992 
emit_builtin_variables()993 void CompilerHLSL::emit_builtin_variables()
994 {
995 	Bitset builtins = active_input_builtins;
996 	builtins.merge_or(active_output_builtins);
997 
998 	bool need_base_vertex_info = false;
999 
1000 	std::unordered_map<uint32_t, ID> builtin_to_initializer;
1001 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1002 		if (!is_builtin_variable(var) || var.storage != StorageClassOutput || !var.initializer)
1003 			return;
1004 
1005 		auto *c = this->maybe_get<SPIRConstant>(var.initializer);
1006 		if (!c)
1007 			return;
1008 
1009 		auto &type = this->get<SPIRType>(var.basetype);
1010 		if (type.basetype == SPIRType::Struct)
1011 		{
1012 			uint32_t member_count = uint32_t(type.member_types.size());
1013 			for (uint32_t i = 0; i < member_count; i++)
1014 			{
1015 				if (has_member_decoration(type.self, i, DecorationBuiltIn))
1016 				{
1017 					builtin_to_initializer[get_member_decoration(type.self, i, DecorationBuiltIn)] =
1018 						c->subconstants[i];
1019 				}
1020 			}
1021 		}
1022 		else if (has_decoration(var.self, DecorationBuiltIn))
1023 			builtin_to_initializer[get_decoration(var.self, DecorationBuiltIn)] = var.initializer;
1024 	});
1025 
1026 	// Emit global variables for the interface variables which are statically used by the shader.
1027 	builtins.for_each_bit([&](uint32_t i) {
1028 		const char *type = nullptr;
1029 		auto builtin = static_cast<BuiltIn>(i);
1030 		uint32_t array_size = 0;
1031 
1032 		string init_expr;
1033 		auto init_itr = builtin_to_initializer.find(builtin);
1034 		if (init_itr != builtin_to_initializer.end())
1035 			init_expr = join(" = ", to_expression(init_itr->second));
1036 
1037 		switch (builtin)
1038 		{
1039 		case BuiltInFragCoord:
1040 		case BuiltInPosition:
1041 			type = "float4";
1042 			break;
1043 
1044 		case BuiltInFragDepth:
1045 			type = "float";
1046 			break;
1047 
1048 		case BuiltInVertexId:
1049 		case BuiltInVertexIndex:
1050 		case BuiltInInstanceIndex:
1051 			type = "int";
1052 			if (hlsl_options.support_nonzero_base_vertex_base_instance)
1053 				need_base_vertex_info = true;
1054 			break;
1055 
1056 		case BuiltInInstanceId:
1057 		case BuiltInSampleId:
1058 			type = "int";
1059 			break;
1060 
1061 		case BuiltInPointSize:
1062 			if (hlsl_options.point_size_compat)
1063 			{
1064 				// Just emit the global variable, it will be ignored.
1065 				type = "float";
1066 				break;
1067 			}
1068 			else
1069 				SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1070 
1071 		case BuiltInGlobalInvocationId:
1072 		case BuiltInLocalInvocationId:
1073 		case BuiltInWorkgroupId:
1074 			type = "uint3";
1075 			break;
1076 
1077 		case BuiltInLocalInvocationIndex:
1078 			type = "uint";
1079 			break;
1080 
1081 		case BuiltInFrontFacing:
1082 			type = "bool";
1083 			break;
1084 
1085 		case BuiltInNumWorkgroups:
1086 		case BuiltInPointCoord:
1087 			// Handled specially.
1088 			break;
1089 
1090 		case BuiltInSubgroupLocalInvocationId:
1091 		case BuiltInSubgroupSize:
1092 			if (hlsl_options.shader_model < 60)
1093 				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1094 			break;
1095 
1096 		case BuiltInSubgroupEqMask:
1097 		case BuiltInSubgroupLtMask:
1098 		case BuiltInSubgroupLeMask:
1099 		case BuiltInSubgroupGtMask:
1100 		case BuiltInSubgroupGeMask:
1101 			if (hlsl_options.shader_model < 60)
1102 				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1103 			type = "uint4";
1104 			break;
1105 
1106 		case BuiltInClipDistance:
1107 			array_size = clip_distance_count;
1108 			type = "float";
1109 			break;
1110 
1111 		case BuiltInCullDistance:
1112 			array_size = cull_distance_count;
1113 			type = "float";
1114 			break;
1115 
1116 		case BuiltInSampleMask:
1117 			type = "int";
1118 			break;
1119 
1120 		default:
1121 			SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1122 		}
1123 
1124 		StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;
1125 
1126 		if (type)
1127 		{
1128 			if (array_size)
1129 				statement("static ", type, " ", builtin_to_glsl(builtin, storage), "[", array_size, "]", init_expr, ";");
1130 			else
1131 				statement("static ", type, " ", builtin_to_glsl(builtin, storage), init_expr, ";");
1132 		}
1133 
1134 		// SampleMask can be both in and out with sample builtin, in this case we have already
1135 		// declared the input variable and we need to add the output one now.
1136 		if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i))
1137 		{
1138 			statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";");
1139 		}
1140 	});
1141 
1142 	if (need_base_vertex_info)
1143 	{
1144 		statement("cbuffer SPIRV_Cross_VertexInfo");
1145 		begin_scope();
1146 		statement("int SPIRV_Cross_BaseVertex;");
1147 		statement("int SPIRV_Cross_BaseInstance;");
1148 		end_scope_decl();
1149 		statement("");
1150 	}
1151 }
1152 
emit_composite_constants()1153 void CompilerHLSL::emit_composite_constants()
1154 {
1155 	// HLSL cannot declare structs or arrays inline, so we must move them out to
1156 	// global constants directly.
1157 	bool emitted = false;
1158 
1159 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
1160 		if (c.specialization)
1161 			return;
1162 
1163 		auto &type = this->get<SPIRType>(c.constant_type);
1164 
1165 		if (type.basetype == SPIRType::Struct && is_builtin_type(type))
1166 			return;
1167 
1168 		if (type.basetype == SPIRType::Struct || !type.array.empty())
1169 		{
1170 			auto name = to_name(c.self);
1171 			statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
1172 			emitted = true;
1173 		}
1174 	});
1175 
1176 	if (emitted)
1177 		statement("");
1178 }
1179 
emit_specialization_constants_and_structs()1180 void CompilerHLSL::emit_specialization_constants_and_structs()
1181 {
1182 	bool emitted = false;
1183 	SpecializationConstant wg_x, wg_y, wg_z;
1184 	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
1185 
1186 	std::unordered_set<TypeID> io_block_types;
1187 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, const SPIRVariable &var) {
1188 		auto &type = this->get<SPIRType>(var.basetype);
1189 		if ((var.storage == StorageClassInput || var.storage == StorageClassOutput) &&
1190 		    !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1191 		    interface_variable_exists_in_entry_point(var.self) &&
1192 		    has_decoration(type.self, DecorationBlock))
1193 		{
1194 			io_block_types.insert(type.self);
1195 		}
1196 	});
1197 
1198 	auto loop_lock = ir.create_loop_hard_lock();
1199 	for (auto &id_ : ir.ids_for_constant_or_type)
1200 	{
1201 		auto &id = ir.ids[id_];
1202 
1203 		if (id.get_type() == TypeConstant)
1204 		{
1205 			auto &c = id.get<SPIRConstant>();
1206 
1207 			if (c.self == workgroup_size_id)
1208 			{
1209 				statement("static const uint3 gl_WorkGroupSize = ",
1210 				          constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
1211 				emitted = true;
1212 			}
1213 			else if (c.specialization)
1214 			{
1215 				auto &type = get<SPIRType>(c.constant_type);
1216 				auto name = to_name(c.self);
1217 
1218 				// HLSL does not support specialization constants, so fallback to macros.
1219 				c.specialization_constant_macro_name =
1220 				    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
1221 
1222 				statement("#ifndef ", c.specialization_constant_macro_name);
1223 				statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
1224 				statement("#endif");
1225 				statement("static const ", variable_decl(type, name), " = ", c.specialization_constant_macro_name, ";");
1226 				emitted = true;
1227 			}
1228 		}
1229 		else if (id.get_type() == TypeConstantOp)
1230 		{
1231 			auto &c = id.get<SPIRConstantOp>();
1232 			auto &type = get<SPIRType>(c.basetype);
1233 			auto name = to_name(c.self);
1234 			statement("static const ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
1235 			emitted = true;
1236 		}
1237 		else if (id.get_type() == TypeType)
1238 		{
1239 			auto &type = id.get<SPIRType>();
1240 			bool is_non_io_block = has_decoration(type.self, DecorationBlock) &&
1241 			                       io_block_types.count(type.self) == 0;
1242 			bool is_buffer_block = has_decoration(type.self, DecorationBufferBlock);
1243 			if (type.basetype == SPIRType::Struct && type.array.empty() &&
1244 			    !type.pointer && !is_non_io_block && !is_buffer_block)
1245 			{
1246 				if (emitted)
1247 					statement("");
1248 				emitted = false;
1249 
1250 				emit_struct(type);
1251 			}
1252 		}
1253 	}
1254 
1255 	if (emitted)
1256 		statement("");
1257 }
1258 
replace_illegal_names()1259 void CompilerHLSL::replace_illegal_names()
1260 {
1261 	static const unordered_set<string> keywords = {
1262 		// Additional HLSL specific keywords.
1263 		"line", "linear", "matrix", "point", "row_major", "sampler", "vector"
1264 	};
1265 
1266 	CompilerGLSL::replace_illegal_names(keywords);
1267 	CompilerGLSL::replace_illegal_names();
1268 }
1269 
declare_undefined_values()1270 void CompilerHLSL::declare_undefined_values()
1271 {
1272 	bool emitted = false;
1273 	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, const SPIRUndef &undef) {
1274 		auto &type = this->get<SPIRType>(undef.basetype);
1275 		// OpUndef can be void for some reason ...
1276 		if (type.basetype == SPIRType::Void)
1277 			return;
1278 
1279 		string initializer;
1280 		if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
1281 			initializer = join(" = ", to_zero_initialized_expression(undef.basetype));
1282 
1283 		statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
1284 		emitted = true;
1285 	});
1286 
1287 	if (emitted)
1288 		statement("");
1289 }
1290 
emit_resources()1291 void CompilerHLSL::emit_resources()
1292 {
1293 	auto &execution = get_entry_point();
1294 
1295 	replace_illegal_names();
1296 
1297 	emit_specialization_constants_and_structs();
1298 	emit_composite_constants();
1299 
1300 	bool emitted = false;
1301 
1302 	// Output UBOs and SSBOs
1303 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1304 		auto &type = this->get<SPIRType>(var.basetype);
1305 
1306 		bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
1307 		bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) ||
1308 		                       ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
1309 
1310 		if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
1311 		    has_block_flags)
1312 		{
1313 			emit_buffer_block(var);
1314 			emitted = true;
1315 		}
1316 	});
1317 
1318 	// Output push constant blocks
1319 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1320 		auto &type = this->get<SPIRType>(var.basetype);
1321 		if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
1322 		    !is_hidden_variable(var))
1323 		{
1324 			emit_push_constant_block(var);
1325 			emitted = true;
1326 		}
1327 	});
1328 
1329 	if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30)
1330 	{
1331 		statement("uniform float4 gl_HalfPixel;");
1332 		emitted = true;
1333 	}
1334 
1335 	bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;
1336 
1337 	// Output Uniform Constants (values, samplers, images, etc).
1338 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1339 		auto &type = this->get<SPIRType>(var.basetype);
1340 
1341 		// If we're remapping separate samplers and images, only emit the combined samplers.
1342 		if (skip_separate_image_sampler)
1343 		{
1344 			// Sampler buffers are always used without a sampler, and they will also work in regular D3D.
1345 			bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
1346 			bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
1347 			bool separate_sampler = type.basetype == SPIRType::Sampler;
1348 			if (!sampler_buffer && (separate_image || separate_sampler))
1349 				return;
1350 		}
1351 
1352 		if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
1353 		    type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter) &&
1354 		    !is_hidden_variable(var))
1355 		{
1356 			emit_uniform(var);
1357 			emitted = true;
1358 		}
1359 	});
1360 
1361 	if (emitted)
1362 		statement("");
1363 	emitted = false;
1364 
1365 	// Emit builtin input and output variables here.
1366 	emit_builtin_variables();
1367 
1368 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1369 		auto &type = this->get<SPIRType>(var.basetype);
1370 
1371 		if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
1372 		    (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
1373 		    interface_variable_exists_in_entry_point(var.self))
1374 		{
1375 			// Builtin variables are handled separately.
1376 			emit_interface_block_globally(var);
1377 			emitted = true;
1378 		}
1379 	});
1380 
1381 	if (emitted)
1382 		statement("");
1383 	emitted = false;
1384 
1385 	require_input = false;
1386 	require_output = false;
1387 	unordered_set<uint32_t> active_inputs;
1388 	unordered_set<uint32_t> active_outputs;
1389 
1390 	struct IOVariable
1391 	{
1392 		const SPIRVariable *var;
1393 		uint32_t location;
1394 		uint32_t block_member_index;
1395 		bool block;
1396 	};
1397 
1398 	SmallVector<IOVariable> input_variables;
1399 	SmallVector<IOVariable> output_variables;
1400 
1401 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1402 		auto &type = this->get<SPIRType>(var.basetype);
1403 		bool block = has_decoration(type.self, DecorationBlock);
1404 
1405 		if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
1406 			return;
1407 
1408 		if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1409 		    interface_variable_exists_in_entry_point(var.self))
1410 		{
1411 			if (block)
1412 			{
1413 				for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
1414 				{
1415 					uint32_t location = get_declared_member_location(var, i, false);
1416 					if (var.storage == StorageClassInput)
1417 						input_variables.push_back({ &var, location, i, true });
1418 					else
1419 						output_variables.push_back({ &var, location, i, true });
1420 				}
1421 			}
1422 			else
1423 			{
1424 				uint32_t location = get_decoration(var.self, DecorationLocation);
1425 				if (var.storage == StorageClassInput)
1426 					input_variables.push_back({ &var, location, 0, false });
1427 				else
1428 					output_variables.push_back({ &var, location, 0, false });
1429 			}
1430 		}
1431 	});
1432 
1433 	const auto variable_compare = [&](const IOVariable &a, const IOVariable &b) -> bool {
1434 		// Sort input and output variables based on, from more robust to less robust:
1435 		// - Location
1436 		// - Variable has a location
1437 		// - Name comparison
1438 		// - Variable has a name
1439 		// - Fallback: ID
1440 		bool has_location_a = a.block || has_decoration(a.var->self, DecorationLocation);
1441 		bool has_location_b = b.block || has_decoration(b.var->self, DecorationLocation);
1442 
1443 		if (has_location_a && has_location_b)
1444 			return a.location < b.location;
1445 		else if (has_location_a && !has_location_b)
1446 			return true;
1447 		else if (!has_location_a && has_location_b)
1448 			return false;
1449 
1450 		const auto &name1 = to_name(a.var->self);
1451 		const auto &name2 = to_name(b.var->self);
1452 
1453 		if (name1.empty() && name2.empty())
1454 			return a.var->self < b.var->self;
1455 		else if (name1.empty())
1456 			return true;
1457 		else if (name2.empty())
1458 			return false;
1459 
1460 		return name1.compare(name2) < 0;
1461 	};
1462 
1463 	auto input_builtins = active_input_builtins;
1464 	input_builtins.clear(BuiltInNumWorkgroups);
1465 	input_builtins.clear(BuiltInPointCoord);
1466 	input_builtins.clear(BuiltInSubgroupSize);
1467 	input_builtins.clear(BuiltInSubgroupLocalInvocationId);
1468 	input_builtins.clear(BuiltInSubgroupEqMask);
1469 	input_builtins.clear(BuiltInSubgroupLtMask);
1470 	input_builtins.clear(BuiltInSubgroupLeMask);
1471 	input_builtins.clear(BuiltInSubgroupGtMask);
1472 	input_builtins.clear(BuiltInSubgroupGeMask);
1473 
1474 	if (!input_variables.empty() || !input_builtins.empty())
1475 	{
1476 		require_input = true;
1477 		statement("struct SPIRV_Cross_Input");
1478 
1479 		begin_scope();
1480 		sort(input_variables.begin(), input_variables.end(), variable_compare);
1481 		for (auto &var : input_variables)
1482 		{
1483 			if (var.block)
1484 				emit_interface_block_member_in_struct(*var.var, var.block_member_index, var.location, active_inputs);
1485 			else
1486 				emit_interface_block_in_struct(*var.var, active_inputs);
1487 		}
1488 		emit_builtin_inputs_in_struct();
1489 		end_scope_decl();
1490 		statement("");
1491 	}
1492 
1493 	if (!output_variables.empty() || !active_output_builtins.empty())
1494 	{
1495 		require_output = true;
1496 		statement("struct SPIRV_Cross_Output");
1497 
1498 		begin_scope();
1499 		sort(output_variables.begin(), output_variables.end(), variable_compare);
1500 		for (auto &var : output_variables)
1501 		{
1502 			if (var.block)
1503 				emit_interface_block_member_in_struct(*var.var, var.block_member_index, var.location, active_outputs);
1504 			else
1505 				emit_interface_block_in_struct(*var.var, active_outputs);
1506 		}
1507 		emit_builtin_outputs_in_struct();
1508 		end_scope_decl();
1509 		statement("");
1510 	}
1511 
1512 	// Global variables.
1513 	for (auto global : global_variables)
1514 	{
1515 		auto &var = get<SPIRVariable>(global);
1516 		if (is_hidden_variable(var, true))
1517 			continue;
1518 
1519 		if (var.storage != StorageClassOutput)
1520 		{
1521 			if (!variable_is_lut(var))
1522 			{
1523 				add_resource_name(var.self);
1524 
1525 				const char *storage = nullptr;
1526 				switch (var.storage)
1527 				{
1528 				case StorageClassWorkgroup:
1529 					storage = "groupshared";
1530 					break;
1531 
1532 				default:
1533 					storage = "static";
1534 					break;
1535 				}
1536 
1537 				string initializer;
1538 				if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
1539 				    !var.initializer && !var.static_expression && type_can_zero_initialize(get_variable_data_type(var)))
1540 				{
1541 					initializer = join(" = ", to_zero_initialized_expression(get_variable_data_type_id(var)));
1542 				}
1543 				statement(storage, " ", variable_decl(var), initializer, ";");
1544 
1545 				emitted = true;
1546 			}
1547 		}
1548 	}
1549 
1550 	if (emitted)
1551 		statement("");
1552 
1553 	declare_undefined_values();
1554 
1555 	if (requires_op_fmod)
1556 	{
1557 		static const char *types[] = {
1558 			"float",
1559 			"float2",
1560 			"float3",
1561 			"float4",
1562 		};
1563 
1564 		for (auto &type : types)
1565 		{
1566 			statement(type, " mod(", type, " x, ", type, " y)");
1567 			begin_scope();
1568 			statement("return x - y * floor(x / y);");
1569 			end_scope();
1570 			statement("");
1571 		}
1572 	}
1573 
1574 	emit_texture_size_variants(required_texture_size_variants.srv, "4", false, "");
1575 	for (uint32_t norm = 0; norm < 3; norm++)
1576 	{
1577 		for (uint32_t comp = 0; comp < 4; comp++)
1578 		{
1579 			static const char *qualifiers[] = { "", "unorm ", "snorm " };
1580 			static const char *vecsizes[] = { "", "2", "3", "4" };
1581 			emit_texture_size_variants(required_texture_size_variants.uav[norm][comp], vecsizes[comp], true,
1582 			                           qualifiers[norm]);
1583 		}
1584 	}
1585 
1586 	if (requires_fp16_packing)
1587 	{
1588 		// HLSL does not pack into a single word sadly :(
1589 		statement("uint spvPackHalf2x16(float2 value)");
1590 		begin_scope();
1591 		statement("uint2 Packed = f32tof16(value);");
1592 		statement("return Packed.x | (Packed.y << 16);");
1593 		end_scope();
1594 		statement("");
1595 
1596 		statement("float2 spvUnpackHalf2x16(uint value)");
1597 		begin_scope();
1598 		statement("return f16tof32(uint2(value & 0xffff, value >> 16));");
1599 		end_scope();
1600 		statement("");
1601 	}
1602 
1603 	if (requires_uint2_packing)
1604 	{
1605 		statement("uint64_t spvPackUint2x32(uint2 value)");
1606 		begin_scope();
1607 		statement("return (uint64_t(value.y) << 32) | uint64_t(value.x);");
1608 		end_scope();
1609 		statement("");
1610 
1611 		statement("uint2 spvUnpackUint2x32(uint64_t value)");
1612 		begin_scope();
1613 		statement("uint2 Unpacked;");
1614 		statement("Unpacked.x = uint(value & 0xffffffff);");
1615 		statement("Unpacked.y = uint(value >> 32);");
1616 		statement("return Unpacked;");
1617 		end_scope();
1618 		statement("");
1619 	}
1620 
1621 	if (requires_explicit_fp16_packing)
1622 	{
1623 		// HLSL does not pack into a single word sadly :(
1624 		statement("uint spvPackFloat2x16(min16float2 value)");
1625 		begin_scope();
1626 		statement("uint2 Packed = f32tof16(value);");
1627 		statement("return Packed.x | (Packed.y << 16);");
1628 		end_scope();
1629 		statement("");
1630 
1631 		statement("min16float2 spvUnpackFloat2x16(uint value)");
1632 		begin_scope();
1633 		statement("return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
1634 		end_scope();
1635 		statement("");
1636 	}
1637 
1638 	// HLSL does not seem to have builtins for these operation, so roll them by hand ...
1639 	if (requires_unorm8_packing)
1640 	{
1641 		statement("uint spvPackUnorm4x8(float4 value)");
1642 		begin_scope();
1643 		statement("uint4 Packed = uint4(round(saturate(value) * 255.0));");
1644 		statement("return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
1645 		end_scope();
1646 		statement("");
1647 
1648 		statement("float4 spvUnpackUnorm4x8(uint value)");
1649 		begin_scope();
1650 		statement("uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
1651 		statement("return float4(Packed) / 255.0;");
1652 		end_scope();
1653 		statement("");
1654 	}
1655 
1656 	if (requires_snorm8_packing)
1657 	{
1658 		statement("uint spvPackSnorm4x8(float4 value)");
1659 		begin_scope();
1660 		statement("int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
1661 		statement("return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
1662 		end_scope();
1663 		statement("");
1664 
1665 		statement("float4 spvUnpackSnorm4x8(uint value)");
1666 		begin_scope();
1667 		statement("int SignedValue = int(value);");
1668 		statement("int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
1669 		statement("return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
1670 		end_scope();
1671 		statement("");
1672 	}
1673 
1674 	if (requires_unorm16_packing)
1675 	{
1676 		statement("uint spvPackUnorm2x16(float2 value)");
1677 		begin_scope();
1678 		statement("uint2 Packed = uint2(round(saturate(value) * 65535.0));");
1679 		statement("return Packed.x | (Packed.y << 16);");
1680 		end_scope();
1681 		statement("");
1682 
1683 		statement("float2 spvUnpackUnorm2x16(uint value)");
1684 		begin_scope();
1685 		statement("uint2 Packed = uint2(value & 0xffff, value >> 16);");
1686 		statement("return float2(Packed) / 65535.0;");
1687 		end_scope();
1688 		statement("");
1689 	}
1690 
1691 	if (requires_snorm16_packing)
1692 	{
1693 		statement("uint spvPackSnorm2x16(float2 value)");
1694 		begin_scope();
1695 		statement("int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
1696 		statement("return uint(Packed.x | (Packed.y << 16));");
1697 		end_scope();
1698 		statement("");
1699 
1700 		statement("float2 spvUnpackSnorm2x16(uint value)");
1701 		begin_scope();
1702 		statement("int SignedValue = int(value);");
1703 		statement("int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
1704 		statement("return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
1705 		end_scope();
1706 		statement("");
1707 	}
1708 
1709 	if (requires_bitfield_insert)
1710 	{
1711 		static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
1712 		for (auto &type : types)
1713 		{
1714 			statement(type, " spvBitfieldInsert(", type, " Base, ", type, " Insert, uint Offset, uint Count)");
1715 			begin_scope();
1716 			statement("uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
1717 			statement("return (Base & ~Mask) | ((Insert << Offset) & Mask);");
1718 			end_scope();
1719 			statement("");
1720 		}
1721 	}
1722 
1723 	if (requires_bitfield_extract)
1724 	{
1725 		static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
1726 		for (auto &type : unsigned_types)
1727 		{
1728 			statement(type, " spvBitfieldUExtract(", type, " Base, uint Offset, uint Count)");
1729 			begin_scope();
1730 			statement("uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
1731 			statement("return (Base >> Offset) & Mask;");
1732 			end_scope();
1733 			statement("");
1734 		}
1735 
1736 		// In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
1737 		static const char *signed_types[] = { "int", "int2", "int3", "int4" };
1738 		for (auto &type : signed_types)
1739 		{
1740 			statement(type, " spvBitfieldSExtract(", type, " Base, int Offset, int Count)");
1741 			begin_scope();
1742 			statement("int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
1743 			statement(type, " Masked = (Base >> Offset) & Mask;");
1744 			statement("int ExtendShift = (32 - Count) & 31;");
1745 			statement("return (Masked << ExtendShift) >> ExtendShift;");
1746 			end_scope();
1747 			statement("");
1748 		}
1749 	}
1750 
1751 	if (requires_inverse_2x2)
1752 	{
1753 		statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1754 		statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1755 		statement("float2x2 spvInverse(float2x2 m)");
1756 		begin_scope();
1757 		statement("float2x2 adj;	// The adjoint matrix (inverse after dividing by determinant)");
1758 		statement_no_indent("");
1759 		statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1760 		statement("adj[0][0] =  m[1][1];");
1761 		statement("adj[0][1] = -m[0][1];");
1762 		statement_no_indent("");
1763 		statement("adj[1][0] = -m[1][0];");
1764 		statement("adj[1][1] =  m[0][0];");
1765 		statement_no_indent("");
1766 		statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1767 		statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
1768 		statement_no_indent("");
1769 		statement("// Divide the classical adjoint matrix by the determinant.");
1770 		statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1771 		statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1772 		end_scope();
1773 		statement("");
1774 	}
1775 
1776 	if (requires_inverse_3x3)
1777 	{
1778 		statement("// Returns the determinant of a 2x2 matrix.");
1779 		statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1780 		begin_scope();
1781 		statement("return a1 * b2 - b1 * a2;");
1782 		end_scope();
1783 		statement_no_indent("");
1784 		statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1785 		statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1786 		statement("float3x3 spvInverse(float3x3 m)");
1787 		begin_scope();
1788 		statement("float3x3 adj;	// The adjoint matrix (inverse after dividing by determinant)");
1789 		statement_no_indent("");
1790 		statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1791 		statement("adj[0][0] =  spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
1792 		statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
1793 		statement("adj[0][2] =  spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
1794 		statement_no_indent("");
1795 		statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
1796 		statement("adj[1][1] =  spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
1797 		statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
1798 		statement_no_indent("");
1799 		statement("adj[2][0] =  spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
1800 		statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
1801 		statement("adj[2][2] =  spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
1802 		statement_no_indent("");
1803 		statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1804 		statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
1805 		statement_no_indent("");
1806 		statement("// Divide the classical adjoint matrix by the determinant.");
1807 		statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1808 		statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1809 		end_scope();
1810 		statement("");
1811 	}
1812 
1813 	if (requires_inverse_4x4)
1814 	{
1815 		if (!requires_inverse_3x3)
1816 		{
1817 			statement("// Returns the determinant of a 2x2 matrix.");
1818 			statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1819 			begin_scope();
1820 			statement("return a1 * b2 - b1 * a2;");
1821 			end_scope();
1822 			statement("");
1823 		}
1824 
1825 		statement("// Returns the determinant of a 3x3 matrix.");
1826 		statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
1827 		          "float c2, float c3)");
1828 		begin_scope();
1829 		statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
1830 		          "spvDet2x2(a2, a3, "
1831 		          "b2, b3);");
1832 		end_scope();
1833 		statement_no_indent("");
1834 		statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1835 		statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1836 		statement("float4x4 spvInverse(float4x4 m)");
1837 		begin_scope();
1838 		statement("float4x4 adj;	// The adjoint matrix (inverse after dividing by determinant)");
1839 		statement_no_indent("");
1840 		statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1841 		statement(
1842 		    "adj[0][0] =  spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1843 		    "m[3][3]);");
1844 		statement(
1845 		    "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1846 		    "m[3][3]);");
1847 		statement(
1848 		    "adj[0][2] =  spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
1849 		    "m[3][3]);");
1850 		statement(
1851 		    "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
1852 		    "m[2][3]);");
1853 		statement_no_indent("");
1854 		statement(
1855 		    "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1856 		    "m[3][3]);");
1857 		statement(
1858 		    "adj[1][1] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1859 		    "m[3][3]);");
1860 		statement(
1861 		    "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
1862 		    "m[3][3]);");
1863 		statement(
1864 		    "adj[1][3] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
1865 		    "m[2][3]);");
1866 		statement_no_indent("");
1867 		statement(
1868 		    "adj[2][0] =  spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1869 		    "m[3][3]);");
1870 		statement(
1871 		    "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1872 		    "m[3][3]);");
1873 		statement(
1874 		    "adj[2][2] =  spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
1875 		    "m[3][3]);");
1876 		statement(
1877 		    "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
1878 		    "m[2][3]);");
1879 		statement_no_indent("");
1880 		statement(
1881 		    "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1882 		    "m[3][2]);");
1883 		statement(
1884 		    "adj[3][1] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1885 		    "m[3][2]);");
1886 		statement(
1887 		    "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
1888 		    "m[3][2]);");
1889 		statement(
1890 		    "adj[3][3] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
1891 		    "m[2][2]);");
1892 		statement_no_indent("");
1893 		statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1894 		statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
1895 		          "* m[3][0]);");
1896 		statement_no_indent("");
1897 		statement("// Divide the classical adjoint matrix by the determinant.");
1898 		statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1899 		statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1900 		end_scope();
1901 		statement("");
1902 	}
1903 
1904 	if (requires_scalar_reflect)
1905 	{
1906 		// FP16/FP64? No templates in HLSL.
1907 		statement("float spvReflect(float i, float n)");
1908 		begin_scope();
1909 		statement("return i - 2.0 * dot(n, i) * n;");
1910 		end_scope();
1911 		statement("");
1912 	}
1913 
1914 	if (requires_scalar_refract)
1915 	{
1916 		// FP16/FP64? No templates in HLSL.
1917 		statement("float spvRefract(float i, float n, float eta)");
1918 		begin_scope();
1919 		statement("float NoI = n * i;");
1920 		statement("float NoI2 = NoI * NoI;");
1921 		statement("float k = 1.0 - eta * eta * (1.0 - NoI2);");
1922 		statement("if (k < 0.0)");
1923 		begin_scope();
1924 		statement("return 0.0;");
1925 		end_scope();
1926 		statement("else");
1927 		begin_scope();
1928 		statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
1929 		end_scope();
1930 		end_scope();
1931 		statement("");
1932 	}
1933 
1934 	if (requires_scalar_faceforward)
1935 	{
1936 		// FP16/FP64? No templates in HLSL.
1937 		statement("float spvFaceForward(float n, float i, float nref)");
1938 		begin_scope();
1939 		statement("return i * nref < 0.0 ? n : -n;");
1940 		end_scope();
1941 		statement("");
1942 	}
1943 
1944 	for (TypeID type_id : composite_selection_workaround_types)
1945 	{
1946 		// Need out variable since HLSL does not support returning arrays.
1947 		auto &type = get<SPIRType>(type_id);
1948 		auto type_str = type_to_glsl(type);
1949 		auto type_arr_str = type_to_array_glsl(type);
1950 		statement("void spvSelectComposite(out ", type_str, " out_value", type_arr_str, ", bool cond, ",
1951 		          type_str, " true_val", type_arr_str, ", ",
1952 		          type_str, " false_val", type_arr_str, ")");
1953 		begin_scope();
1954 		statement("if (cond)");
1955 		begin_scope();
1956 		statement("out_value = true_val;");
1957 		end_scope();
1958 		statement("else");
1959 		begin_scope();
1960 		statement("out_value = false_val;");
1961 		end_scope();
1962 		end_scope();
1963 		statement("");
1964 	}
1965 }
1966 
emit_texture_size_variants(uint64_t variant_mask,const char * vecsize_qualifier,bool uav,const char * type_qualifier)1967 void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
1968                                               const char *type_qualifier)
1969 {
1970 	if (variant_mask == 0)
1971 		return;
1972 
1973 	static const char *types[QueryTypeCount] = { "float", "int", "uint" };
1974 	static const char *dims[QueryDimCount] = { "Texture1D",   "Texture1DArray",  "Texture2D",   "Texture2DArray",
1975 		                                       "Texture3D",   "Buffer",          "TextureCube", "TextureCubeArray",
1976 		                                       "Texture2DMS", "Texture2DMSArray" };
1977 
1978 	static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
1979 
1980 	static const char *ret_types[QueryDimCount] = {
1981 		"uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
1982 	};
1983 
1984 	static const uint32_t return_arguments[QueryDimCount] = {
1985 		1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
1986 	};
1987 
1988 	for (uint32_t index = 0; index < QueryDimCount; index++)
1989 	{
1990 		for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
1991 		{
1992 			uint32_t bit = 16 * type_index + index;
1993 			uint64_t mask = 1ull << bit;
1994 
1995 			if ((variant_mask & mask) == 0)
1996 				continue;
1997 
1998 			statement(ret_types[index], " spv", (uav ? "Image" : "Texture"), "Size(", (uav ? "RW" : ""),
1999 			          dims[index], "<", type_qualifier, types[type_index], vecsize_qualifier, "> Tex, ",
2000 			          (uav ? "" : "uint Level, "), "out uint Param)");
2001 			begin_scope();
2002 			statement(ret_types[index], " ret;");
2003 			switch (return_arguments[index])
2004 			{
2005 			case 1:
2006 				if (has_lod[index] && !uav)
2007 					statement("Tex.GetDimensions(Level, ret.x, Param);");
2008 				else
2009 				{
2010 					statement("Tex.GetDimensions(ret.x);");
2011 					statement("Param = 0u;");
2012 				}
2013 				break;
2014 			case 2:
2015 				if (has_lod[index] && !uav)
2016 					statement("Tex.GetDimensions(Level, ret.x, ret.y, Param);");
2017 				else if (!uav)
2018 					statement("Tex.GetDimensions(ret.x, ret.y, Param);");
2019 				else
2020 				{
2021 					statement("Tex.GetDimensions(ret.x, ret.y);");
2022 					statement("Param = 0u;");
2023 				}
2024 				break;
2025 			case 3:
2026 				if (has_lod[index] && !uav)
2027 					statement("Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
2028 				else if (!uav)
2029 					statement("Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
2030 				else
2031 				{
2032 					statement("Tex.GetDimensions(ret.x, ret.y, ret.z);");
2033 					statement("Param = 0u;");
2034 				}
2035 				break;
2036 			}
2037 
2038 			statement("return ret;");
2039 			end_scope();
2040 			statement("");
2041 		}
2042 	}
2043 }
2044 
layout_for_member(const SPIRType & type,uint32_t index)2045 string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2046 {
2047 	auto &flags = get_member_decoration_bitset(type.self, index);
2048 
2049 	// HLSL can emit row_major or column_major decoration in any struct.
2050 	// Do not try to merge combined decorations for children like in GLSL.
2051 
2052 	// Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2053 	// The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
2054 	if (flags.get(DecorationColMajor))
2055 		return "row_major ";
2056 	else if (flags.get(DecorationRowMajor))
2057 		return "column_major ";
2058 
2059 	return "";
2060 }
2061 
emit_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier,uint32_t base_offset)2062 void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2063                                       const string &qualifier, uint32_t base_offset)
2064 {
2065 	auto &membertype = get<SPIRType>(member_type_id);
2066 
2067 	Bitset memberflags;
2068 	auto &memb = ir.meta[type.self].members;
2069 	if (index < memb.size())
2070 		memberflags = memb[index].decoration_flags;
2071 
2072 	string packing_offset;
2073 	bool is_push_constant = type.storage == StorageClassPushConstant;
2074 
2075 	if ((has_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2076 	    has_member_decoration(type.self, index, DecorationOffset))
2077 	{
2078 		uint32_t offset = memb[index].offset - base_offset;
2079 		if (offset & 3)
2080 			SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2081 
2082 		static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
2083 		packing_offset = join(" : packoffset(c", offset / 16, packing_swizzle[(offset & 15) >> 2], ")");
2084 	}
2085 
2086 	statement(layout_for_member(type, index), qualifier,
2087 	          variable_decl(membertype, to_member_name(type, index)), packing_offset, ";");
2088 }
2089 
emit_buffer_block(const SPIRVariable & var)2090 void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2091 {
2092 	auto &type = get<SPIRType>(var.basetype);
2093 
2094 	bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock);
2095 
2096 	if (is_uav)
2097 	{
2098 		Bitset flags = ir.get_buffer_block_flags(var);
2099 		bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
2100 		bool is_coherent = flags.get(DecorationCoherent) && !is_readonly;
2101 		bool is_interlocked = interlocked_resources.count(var.self) > 0;
2102 		const char *type_name = "ByteAddressBuffer ";
2103 		if (!is_readonly)
2104 			type_name = is_interlocked ? "RasterizerOrderedByteAddressBuffer " : "RWByteAddressBuffer ";
2105 		add_resource_name(var.self);
2106 		statement(is_coherent ? "globallycoherent " : "", type_name, to_name(var.self), type_to_array_glsl(type),
2107 		          to_resource_binding(var), ";");
2108 	}
2109 	else
2110 	{
2111 		if (type.array.empty())
2112 		{
2113 			// Flatten the top-level struct so we can use packoffset,
2114 			// this restriction is similar to GLSL where layout(offset) is not possible on sub-structs.
2115 			flattened_structs[var.self] = false;
2116 
2117 			// Prefer the block name if possible.
2118 			auto buffer_name = to_name(type.self, false);
2119 			if (ir.meta[type.self].decoration.alias.empty() ||
2120 			    resource_names.find(buffer_name) != end(resource_names) ||
2121 			    block_names.find(buffer_name) != end(block_names))
2122 			{
2123 				buffer_name = get_block_fallback_name(var.self);
2124 			}
2125 
2126 			add_variable(block_names, resource_names, buffer_name);
2127 
2128 			// If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2129 			// This cannot conflict with anything else, so we're safe now.
2130 			if (buffer_name.empty())
2131 				buffer_name = join("_", get<SPIRType>(var.basetype).self, "_", var.self);
2132 
2133 			uint32_t failed_index = 0;
2134 			if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index))
2135 				set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2136 			else
2137 			{
2138 				SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2139 				                       failed_index, " (name: ", to_member_name(type, failed_index),
2140 				                       ") cannot be expressed with either HLSL packing layout or packoffset."));
2141 			}
2142 
2143 			block_names.insert(buffer_name);
2144 
2145 			// Save for post-reflection later.
2146 			declared_block_names[var.self] = buffer_name;
2147 
2148 			type.member_name_cache.clear();
2149 			// var.self can be used as a backup name for the block name,
2150 			// so we need to make sure we don't disturb the name here on a recompile.
2151 			// It will need to be reset if we have to recompile.
2152 			preserve_alias_on_reset(var.self);
2153 			add_resource_name(var.self);
2154 			statement("cbuffer ", buffer_name, to_resource_binding(var));
2155 			begin_scope();
2156 
2157 			uint32_t i = 0;
2158 			for (auto &member : type.member_types)
2159 			{
2160 				add_member_name(type, i);
2161 				auto backup_name = get_member_name(type.self, i);
2162 				auto member_name = to_member_name(type, i);
2163 				member_name = join(to_name(var.self), "_", member_name);
2164 				ParsedIR::sanitize_underscores(member_name);
2165 				set_member_name(type.self, i, member_name);
2166 				emit_struct_member(type, member, i, "");
2167 				set_member_name(type.self, i, backup_name);
2168 				i++;
2169 			}
2170 
2171 			end_scope_decl();
2172 			statement("");
2173 		}
2174 		else
2175 		{
2176 			if (hlsl_options.shader_model < 51)
2177 				SPIRV_CROSS_THROW(
2178 				    "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2179 
2180 			add_resource_name(type.self);
2181 			add_resource_name(var.self);
2182 
2183 			// ConstantBuffer<T> does not support packoffset, so it is unuseable unless everything aligns as we expect.
2184 			uint32_t failed_index = 0;
2185 			if (!buffer_is_packing_standard(type, BufferPackingHLSLCbuffer, &failed_index))
2186 			{
2187 				SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2188 				                       "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2189 				                       ") cannot be expressed with normal HLSL packing rules."));
2190 			}
2191 
2192 			emit_struct(get<SPIRType>(type.self));
2193 			statement("ConstantBuffer<", to_name(type.self), "> ", to_name(var.self), type_to_array_glsl(type),
2194 			          to_resource_binding(var), ";");
2195 		}
2196 	}
2197 }
2198 
emit_push_constant_block(const SPIRVariable & var)2199 void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2200 {
2201 	if (root_constants_layout.empty())
2202 	{
2203 		emit_buffer_block(var);
2204 	}
2205 	else
2206 	{
2207 		for (const auto &layout : root_constants_layout)
2208 		{
2209 			auto &type = get<SPIRType>(var.basetype);
2210 
2211 			uint32_t failed_index = 0;
2212 			if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index, layout.start,
2213 			                               layout.end))
2214 				set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2215 			else
2216 			{
2217 				SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2218 				                       ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2219 				                       ") cannot be expressed with either HLSL packing layout or packoffset."));
2220 			}
2221 
2222 			flattened_structs[var.self] = false;
2223 			type.member_name_cache.clear();
2224 			add_resource_name(var.self);
2225 			auto &memb = ir.meta[type.self].members;
2226 
2227 			statement("cbuffer SPIRV_CROSS_RootConstant_", to_name(var.self),
2228 			          to_resource_register(HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, 'b', layout.binding, layout.space));
2229 			begin_scope();
2230 
2231 			// Index of the next field in the generated root constant constant buffer
2232 			auto constant_index = 0u;
2233 
2234 			// Iterate over all member of the push constant and check which of the fields
2235 			// fit into the given root constant layout.
2236 			for (auto i = 0u; i < memb.size(); i++)
2237 			{
2238 				const auto offset = memb[i].offset;
2239 				if (layout.start <= offset && offset < layout.end)
2240 				{
2241 					const auto &member = type.member_types[i];
2242 
2243 					add_member_name(type, constant_index);
2244 					auto backup_name = get_member_name(type.self, i);
2245 					auto member_name = to_member_name(type, i);
2246 					member_name = join(to_name(var.self), "_", member_name);
2247 					ParsedIR::sanitize_underscores(member_name);
2248 					set_member_name(type.self, constant_index, member_name);
2249 					emit_struct_member(type, member, i, "", layout.start);
2250 					set_member_name(type.self, constant_index, backup_name);
2251 
2252 					constant_index++;
2253 				}
2254 			}
2255 
2256 			end_scope_decl();
2257 		}
2258 	}
2259 }
2260 
to_sampler_expression(uint32_t id)2261 string CompilerHLSL::to_sampler_expression(uint32_t id)
2262 {
2263 	auto expr = join("_", to_non_uniform_aware_expression(id));
2264 	auto index = expr.find_first_of('[');
2265 	if (index == string::npos)
2266 	{
2267 		return expr + "_sampler";
2268 	}
2269 	else
2270 	{
2271 		// We have an expression like _ident[array], so we cannot tack on _sampler, insert it inside the string instead.
2272 		return expr.insert(index, "_sampler");
2273 	}
2274 }
2275 
emit_sampled_image_op(uint32_t result_type,uint32_t result_id,uint32_t image_id,uint32_t samp_id)2276 void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2277 {
2278 	if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2279 	{
2280 		set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
2281 	}
2282 	else
2283 	{
2284 		// Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2285 		emit_op(result_type, result_id, to_combined_image_sampler(image_id, samp_id), true, true);
2286 	}
2287 }
2288 
to_func_call_arg(const SPIRFunction::Parameter & arg,uint32_t id)2289 string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2290 {
2291 	string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2292 
2293 	if (hlsl_options.shader_model <= 30)
2294 		return arg_str;
2295 
2296 	// Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2297 	auto &type = expression_type(id);
2298 
2299 	// We don't have to consider combined image samplers here via OpSampledImage because
2300 	// those variables cannot be passed as arguments to functions.
2301 	// Only global SampledImage variables may be used as arguments.
2302 	if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2303 		arg_str += ", " + to_sampler_expression(id);
2304 
2305 	return arg_str;
2306 }
2307 
emit_function_prototype(SPIRFunction & func,const Bitset & return_flags)2308 void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2309 {
2310 	if (func.self != ir.default_entry_point)
2311 		add_function_overload(func);
2312 
2313 	auto &execution = get_entry_point();
2314 	// Avoid shadow declarations.
2315 	local_variable_names = resource_names;
2316 
2317 	string decl;
2318 
2319 	auto &type = get<SPIRType>(func.return_type);
2320 	if (type.array.empty())
2321 	{
2322 		decl += flags_to_qualifiers_glsl(type, return_flags);
2323 		decl += type_to_glsl(type);
2324 		decl += " ";
2325 	}
2326 	else
2327 	{
2328 		// We cannot return arrays in HLSL, so "return" through an out variable.
2329 		decl = "void ";
2330 	}
2331 
2332 	if (func.self == ir.default_entry_point)
2333 	{
2334 		if (execution.model == ExecutionModelVertex)
2335 			decl += "vert_main";
2336 		else if (execution.model == ExecutionModelFragment)
2337 			decl += "frag_main";
2338 		else if (execution.model == ExecutionModelGLCompute)
2339 			decl += "comp_main";
2340 		else
2341 			SPIRV_CROSS_THROW("Unsupported execution model.");
2342 		processing_entry_point = true;
2343 	}
2344 	else
2345 		decl += to_name(func.self);
2346 
2347 	decl += "(";
2348 	SmallVector<string> arglist;
2349 
2350 	if (!type.array.empty())
2351 	{
2352 		// Fake array returns by writing to an out array instead.
2353 		string out_argument;
2354 		out_argument += "out ";
2355 		out_argument += type_to_glsl(type);
2356 		out_argument += " ";
2357 		out_argument += "spvReturnValue";
2358 		out_argument += type_to_array_glsl(type);
2359 		arglist.push_back(move(out_argument));
2360 	}
2361 
2362 	for (auto &arg : func.arguments)
2363 	{
2364 		// Do not pass in separate images or samplers if we're remapping
2365 		// to combined image samplers.
2366 		if (skip_argument(arg.id))
2367 			continue;
2368 
2369 		// Might change the variable name if it already exists in this function.
2370 		// SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
2371 		// to use same name for variables.
2372 		// Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates.
2373 		add_local_variable_name(arg.id);
2374 
2375 		arglist.push_back(argument_decl(arg));
2376 
2377 		// Flatten a combined sampler to two separate arguments in modern HLSL.
2378 		auto &arg_type = get<SPIRType>(arg.type);
2379 		if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
2380 		    arg_type.image.dim != DimBuffer)
2381 		{
2382 			// Manufacture automatic sampler arg for SampledImage texture
2383 			arglist.push_back(join(image_is_comparison(arg_type, arg.id) ? "SamplerComparisonState " : "SamplerState ",
2384 			                       to_sampler_expression(arg.id), type_to_array_glsl(arg_type)));
2385 		}
2386 
2387 		// Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2388 		auto *var = maybe_get<SPIRVariable>(arg.id);
2389 		if (var)
2390 			var->parameter = &arg;
2391 	}
2392 
2393 	for (auto &arg : func.shadow_arguments)
2394 	{
2395 		// Might change the variable name if it already exists in this function.
2396 		// SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
2397 		// to use same name for variables.
2398 		// Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates.
2399 		add_local_variable_name(arg.id);
2400 
2401 		arglist.push_back(argument_decl(arg));
2402 
2403 		// Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2404 		auto *var = maybe_get<SPIRVariable>(arg.id);
2405 		if (var)
2406 			var->parameter = &arg;
2407 	}
2408 
2409 	decl += merge(arglist);
2410 	decl += ")";
2411 	statement(decl);
2412 }
2413 
emit_hlsl_entry_point()2414 void CompilerHLSL::emit_hlsl_entry_point()
2415 {
2416 	SmallVector<string> arguments;
2417 
2418 	if (require_input)
2419 		arguments.push_back("SPIRV_Cross_Input stage_input");
2420 
2421 	auto &execution = get_entry_point();
2422 
2423 	switch (execution.model)
2424 	{
2425 	case ExecutionModelGLCompute:
2426 	{
2427 		SpecializationConstant wg_x, wg_y, wg_z;
2428 		get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
2429 
2430 		uint32_t x = execution.workgroup_size.x;
2431 		uint32_t y = execution.workgroup_size.y;
2432 		uint32_t z = execution.workgroup_size.z;
2433 
2434 		auto x_expr = wg_x.id ? get<SPIRConstant>(wg_x.id).specialization_constant_macro_name : to_string(x);
2435 		auto y_expr = wg_y.id ? get<SPIRConstant>(wg_y.id).specialization_constant_macro_name : to_string(y);
2436 		auto z_expr = wg_z.id ? get<SPIRConstant>(wg_z.id).specialization_constant_macro_name : to_string(z);
2437 
2438 		statement("[numthreads(", x_expr, ", ", y_expr, ", ", z_expr, ")]");
2439 		break;
2440 	}
2441 	case ExecutionModelFragment:
2442 		if (execution.flags.get(ExecutionModeEarlyFragmentTests))
2443 			statement("[earlydepthstencil]");
2444 		break;
2445 	default:
2446 		break;
2447 	}
2448 
2449 	statement(require_output ? "SPIRV_Cross_Output " : "void ", "main(", merge(arguments), ")");
2450 	begin_scope();
2451 	bool legacy = hlsl_options.shader_model <= 30;
2452 
2453 	// Copy builtins from entry point arguments to globals.
2454 	active_input_builtins.for_each_bit([&](uint32_t i) {
2455 		auto builtin = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassInput);
2456 		switch (static_cast<BuiltIn>(i))
2457 		{
2458 		case BuiltInFragCoord:
2459 			// VPOS in D3D9 is sampled at integer locations, apply half-pixel offset to be consistent.
2460 			// TODO: Do we need an option here? Any reason why a D3D9 shader would be used
2461 			// on a D3D10+ system with a different rasterization config?
2462 			if (legacy)
2463 				statement(builtin, " = stage_input.", builtin, " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
2464 			else
2465 			{
2466 				statement(builtin, " = stage_input.", builtin, ";");
2467 				// ZW are undefined in D3D9, only do this fixup here.
2468 				statement(builtin, ".w = 1.0 / ", builtin, ".w;");
2469 			}
2470 			break;
2471 
2472 		case BuiltInVertexId:
2473 		case BuiltInVertexIndex:
2474 		case BuiltInInstanceIndex:
2475 			// D3D semantics are uint, but shader wants int.
2476 			if (hlsl_options.support_nonzero_base_vertex_base_instance)
2477 			{
2478 				if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
2479 					statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
2480 				else
2481 					statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
2482 			}
2483 			else
2484 				statement(builtin, " = int(stage_input.", builtin, ");");
2485 			break;
2486 
2487 		case BuiltInInstanceId:
2488 			// D3D semantics are uint, but shader wants int.
2489 			statement(builtin, " = int(stage_input.", builtin, ");");
2490 			break;
2491 
2492 		case BuiltInNumWorkgroups:
2493 		case BuiltInPointCoord:
2494 		case BuiltInSubgroupSize:
2495 		case BuiltInSubgroupLocalInvocationId:
2496 			break;
2497 
2498 		case BuiltInSubgroupEqMask:
2499 			// Emulate these ...
2500 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2501 			statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
2502 			statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
2503 			statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
2504 			statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
2505 			statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
2506 			break;
2507 
2508 		case BuiltInSubgroupGeMask:
2509 			// Emulate these ...
2510 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2511 			statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
2512 			statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
2513 			statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
2514 			statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
2515 			statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
2516 			statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
2517 			statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
2518 			break;
2519 
2520 		case BuiltInSubgroupGtMask:
2521 			// Emulate these ...
2522 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2523 			statement("uint gt_lane_index = WaveGetLaneIndex() + 1;");
2524 			statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
2525 			statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
2526 			statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
2527 			statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
2528 			statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
2529 			statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
2530 			statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
2531 			statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
2532 			break;
2533 
2534 		case BuiltInSubgroupLeMask:
2535 			// Emulate these ...
2536 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2537 			statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
2538 			statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
2539 			statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
2540 			statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
2541 			statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
2542 			statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
2543 			statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
2544 			statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
2545 			statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
2546 			break;
2547 
2548 		case BuiltInSubgroupLtMask:
2549 			// Emulate these ...
2550 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2551 			statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
2552 			statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
2553 			statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
2554 			statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
2555 			statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
2556 			statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
2557 			statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
2558 			break;
2559 
2560 		case BuiltInClipDistance:
2561 			for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2562 				statement("gl_ClipDistance[", clip, "] = stage_input.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3],
2563 				          ";");
2564 			break;
2565 
2566 		case BuiltInCullDistance:
2567 			for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2568 				statement("gl_CullDistance[", cull, "] = stage_input.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3],
2569 				          ";");
2570 			break;
2571 
2572 		default:
2573 			statement(builtin, " = stage_input.", builtin, ";");
2574 			break;
2575 		}
2576 	});
2577 
2578 	// Copy from stage input struct to globals.
2579 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2580 		auto &type = this->get<SPIRType>(var.basetype);
2581 		bool block = has_decoration(type.self, DecorationBlock);
2582 
2583 		if (var.storage != StorageClassInput)
2584 			return;
2585 
2586 		bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
2587 
2588 		if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
2589 		    interface_variable_exists_in_entry_point(var.self))
2590 		{
2591 			if (block)
2592 			{
2593 				auto type_name = to_name(type.self);
2594 				auto var_name = to_name(var.self);
2595 				for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2596 				{
2597 					auto mbr_name = to_member_name(type, mbr_idx);
2598 					auto flat_name = join(type_name, "_", mbr_name);
2599 					statement(var_name, ".", mbr_name, " = stage_input.", flat_name, ";");
2600 				}
2601 			}
2602 			else
2603 			{
2604 				auto name = to_name(var.self);
2605 				auto &mtype = this->get<SPIRType>(var.basetype);
2606 				if (need_matrix_unroll && mtype.columns > 1)
2607 				{
2608 					// Unroll matrices.
2609 					for (uint32_t col = 0; col < mtype.columns; col++)
2610 						statement(name, "[", col, "] = stage_input.", name, "_", col, ";");
2611 				}
2612 				else
2613 				{
2614 					statement(name, " = stage_input.", name, ";");
2615 				}
2616 			}
2617 		}
2618 	});
2619 
2620 	// Run the shader.
2621 	if (execution.model == ExecutionModelVertex)
2622 		statement("vert_main();");
2623 	else if (execution.model == ExecutionModelFragment)
2624 		statement("frag_main();");
2625 	else if (execution.model == ExecutionModelGLCompute)
2626 		statement("comp_main();");
2627 	else
2628 		SPIRV_CROSS_THROW("Unsupported shader stage.");
2629 
2630 	// Copy stage outputs.
2631 	if (require_output)
2632 	{
2633 		statement("SPIRV_Cross_Output stage_output;");
2634 
2635 		// Copy builtins from globals to return struct.
2636 		active_output_builtins.for_each_bit([&](uint32_t i) {
2637 			// PointSize doesn't exist in HLSL.
2638 			if (i == BuiltInPointSize)
2639 				return;
2640 
2641 			switch (static_cast<BuiltIn>(i))
2642 			{
2643 			case BuiltInClipDistance:
2644 				for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2645 					statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[",
2646 					          clip, "];");
2647 				break;
2648 
2649 			case BuiltInCullDistance:
2650 				for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2651 					statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[",
2652 					          cull, "];");
2653 				break;
2654 
2655 			default:
2656 			{
2657 				auto builtin_expr = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassOutput);
2658 				statement("stage_output.", builtin_expr, " = ", builtin_expr, ";");
2659 				break;
2660 			}
2661 			}
2662 		});
2663 
2664 		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2665 			auto &type = this->get<SPIRType>(var.basetype);
2666 			bool block = has_decoration(type.self, DecorationBlock);
2667 
2668 			if (var.storage != StorageClassOutput)
2669 				return;
2670 
2671 			if (!var.remapped_variable && type.pointer &&
2672 			    !is_builtin_variable(var) &&
2673 			    interface_variable_exists_in_entry_point(var.self))
2674 			{
2675 				if (block)
2676 				{
2677 					// I/O blocks need to flatten output.
2678 					auto type_name = to_name(type.self);
2679 					auto var_name = to_name(var.self);
2680 					for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2681 					{
2682 						auto mbr_name = to_member_name(type, mbr_idx);
2683 						auto flat_name = join(type_name, "_", mbr_name);
2684 						statement("stage_output.", flat_name, " = ", var_name, ".", mbr_name, ";");
2685 					}
2686 				}
2687 				else
2688 				{
2689 					auto name = to_name(var.self);
2690 
2691 					if (legacy && execution.model == ExecutionModelFragment)
2692 					{
2693 						string output_filler;
2694 						for (uint32_t size = type.vecsize; size < 4; ++size)
2695 							output_filler += ", 0.0";
2696 
2697 						statement("stage_output.", name, " = float4(", name, output_filler, ");");
2698 					}
2699 					else
2700 					{
2701 						statement("stage_output.", name, " = ", name, ";");
2702 					}
2703 				}
2704 			}
2705 		});
2706 
2707 		statement("return stage_output;");
2708 	}
2709 
2710 	end_scope();
2711 }
2712 
emit_fixup()2713 void CompilerHLSL::emit_fixup()
2714 {
2715 	if (is_vertex_like_shader())
2716 	{
2717 		// Do various mangling on the gl_Position.
2718 		if (hlsl_options.shader_model <= 30)
2719 		{
2720 			statement("gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
2721 			          "gl_Position.w;");
2722 			statement("gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
2723 			          "gl_Position.w;");
2724 		}
2725 
2726 		if (options.vertex.flip_vert_y)
2727 			statement("gl_Position.y = -gl_Position.y;");
2728 		if (options.vertex.fixup_clipspace)
2729 			statement("gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
2730 	}
2731 }
2732 
emit_texture_op(const Instruction & i,bool sparse)2733 void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
2734 {
2735 	if (sparse)
2736 		SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
2737 
2738 	auto *ops = stream(i);
2739 	auto op = static_cast<Op>(i.op);
2740 	uint32_t length = i.length;
2741 
2742 	SmallVector<uint32_t> inherited_expressions;
2743 
2744 	uint32_t result_type = ops[0];
2745 	uint32_t id = ops[1];
2746 	VariableID img = ops[2];
2747 	uint32_t coord = ops[3];
2748 	uint32_t dref = 0;
2749 	uint32_t comp = 0;
2750 	bool gather = false;
2751 	bool proj = false;
2752 	const uint32_t *opt = nullptr;
2753 	auto *combined_image = maybe_get<SPIRCombinedImageSampler>(img);
2754 
2755 	if (combined_image && has_decoration(img, DecorationNonUniform))
2756 	{
2757 		set_decoration(combined_image->image, DecorationNonUniform);
2758 		set_decoration(combined_image->sampler, DecorationNonUniform);
2759 	}
2760 
2761 	auto img_expr = to_non_uniform_aware_expression(combined_image ? combined_image->image : img);
2762 
2763 	inherited_expressions.push_back(coord);
2764 
2765 	switch (op)
2766 	{
2767 	case OpImageSampleDrefImplicitLod:
2768 	case OpImageSampleDrefExplicitLod:
2769 		dref = ops[4];
2770 		opt = &ops[5];
2771 		length -= 5;
2772 		break;
2773 
2774 	case OpImageSampleProjDrefImplicitLod:
2775 	case OpImageSampleProjDrefExplicitLod:
2776 		dref = ops[4];
2777 		proj = true;
2778 		opt = &ops[5];
2779 		length -= 5;
2780 		break;
2781 
2782 	case OpImageDrefGather:
2783 		dref = ops[4];
2784 		opt = &ops[5];
2785 		gather = true;
2786 		length -= 5;
2787 		break;
2788 
2789 	case OpImageGather:
2790 		comp = ops[4];
2791 		opt = &ops[5];
2792 		gather = true;
2793 		length -= 5;
2794 		break;
2795 
2796 	case OpImageSampleProjImplicitLod:
2797 	case OpImageSampleProjExplicitLod:
2798 		opt = &ops[4];
2799 		length -= 4;
2800 		proj = true;
2801 		break;
2802 
2803 	case OpImageQueryLod:
2804 		opt = &ops[4];
2805 		length -= 4;
2806 		break;
2807 
2808 	default:
2809 		opt = &ops[4];
2810 		length -= 4;
2811 		break;
2812 	}
2813 
2814 	auto &imgtype = expression_type(img);
2815 	uint32_t coord_components = 0;
2816 	switch (imgtype.image.dim)
2817 	{
2818 	case spv::Dim1D:
2819 		coord_components = 1;
2820 		break;
2821 	case spv::Dim2D:
2822 		coord_components = 2;
2823 		break;
2824 	case spv::Dim3D:
2825 		coord_components = 3;
2826 		break;
2827 	case spv::DimCube:
2828 		coord_components = 3;
2829 		break;
2830 	case spv::DimBuffer:
2831 		coord_components = 1;
2832 		break;
2833 	default:
2834 		coord_components = 2;
2835 		break;
2836 	}
2837 
2838 	if (dref)
2839 		inherited_expressions.push_back(dref);
2840 
2841 	if (imgtype.image.arrayed)
2842 		coord_components++;
2843 
2844 	uint32_t bias = 0;
2845 	uint32_t lod = 0;
2846 	uint32_t grad_x = 0;
2847 	uint32_t grad_y = 0;
2848 	uint32_t coffset = 0;
2849 	uint32_t offset = 0;
2850 	uint32_t coffsets = 0;
2851 	uint32_t sample = 0;
2852 	uint32_t minlod = 0;
2853 	uint32_t flags = 0;
2854 
2855 	if (length)
2856 	{
2857 		flags = opt[0];
2858 		opt++;
2859 		length--;
2860 	}
2861 
2862 	auto test = [&](uint32_t &v, uint32_t flag) {
2863 		if (length && (flags & flag))
2864 		{
2865 			v = *opt++;
2866 			inherited_expressions.push_back(v);
2867 			length--;
2868 		}
2869 	};
2870 
2871 	test(bias, ImageOperandsBiasMask);
2872 	test(lod, ImageOperandsLodMask);
2873 	test(grad_x, ImageOperandsGradMask);
2874 	test(grad_y, ImageOperandsGradMask);
2875 	test(coffset, ImageOperandsConstOffsetMask);
2876 	test(offset, ImageOperandsOffsetMask);
2877 	test(coffsets, ImageOperandsConstOffsetsMask);
2878 	test(sample, ImageOperandsSampleMask);
2879 	test(minlod, ImageOperandsMinLodMask);
2880 
2881 	string expr;
2882 	string texop;
2883 
2884 	if (minlod != 0)
2885 		SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
2886 
2887 	if (op == OpImageFetch)
2888 	{
2889 		if (hlsl_options.shader_model < 40)
2890 		{
2891 			SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
2892 		}
2893 		texop += img_expr;
2894 		texop += ".Load";
2895 	}
2896 	else if (op == OpImageQueryLod)
2897 	{
2898 		texop += img_expr;
2899 		texop += ".CalculateLevelOfDetail";
2900 	}
2901 	else
2902 	{
2903 		auto &imgformat = get<SPIRType>(imgtype.image.type);
2904 		if (imgformat.basetype != SPIRType::Float)
2905 		{
2906 			SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL.");
2907 		}
2908 
2909 		if (hlsl_options.shader_model >= 40)
2910 		{
2911 			texop += img_expr;
2912 
2913 			if (image_is_comparison(imgtype, img))
2914 			{
2915 				if (gather)
2916 				{
2917 					SPIRV_CROSS_THROW("GatherCmp does not exist in HLSL.");
2918 				}
2919 				else if (lod || grad_x || grad_y)
2920 				{
2921 					// Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
2922 					texop += ".SampleCmpLevelZero";
2923 				}
2924 				else
2925 					texop += ".SampleCmp";
2926 			}
2927 			else if (gather)
2928 			{
2929 				uint32_t comp_num = evaluate_constant_u32(comp);
2930 				if (hlsl_options.shader_model >= 50)
2931 				{
2932 					switch (comp_num)
2933 					{
2934 					case 0:
2935 						texop += ".GatherRed";
2936 						break;
2937 					case 1:
2938 						texop += ".GatherGreen";
2939 						break;
2940 					case 2:
2941 						texop += ".GatherBlue";
2942 						break;
2943 					case 3:
2944 						texop += ".GatherAlpha";
2945 						break;
2946 					default:
2947 						SPIRV_CROSS_THROW("Invalid component.");
2948 					}
2949 				}
2950 				else
2951 				{
2952 					if (comp_num == 0)
2953 						texop += ".Gather";
2954 					else
2955 						SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
2956 				}
2957 			}
2958 			else if (bias)
2959 				texop += ".SampleBias";
2960 			else if (grad_x || grad_y)
2961 				texop += ".SampleGrad";
2962 			else if (lod)
2963 				texop += ".SampleLevel";
2964 			else
2965 				texop += ".Sample";
2966 		}
2967 		else
2968 		{
2969 			switch (imgtype.image.dim)
2970 			{
2971 			case Dim1D:
2972 				texop += "tex1D";
2973 				break;
2974 			case Dim2D:
2975 				texop += "tex2D";
2976 				break;
2977 			case Dim3D:
2978 				texop += "tex3D";
2979 				break;
2980 			case DimCube:
2981 				texop += "texCUBE";
2982 				break;
2983 			case DimRect:
2984 			case DimBuffer:
2985 			case DimSubpassData:
2986 				SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL"); // TODO
2987 			default:
2988 				SPIRV_CROSS_THROW("Invalid dimension.");
2989 			}
2990 
2991 			if (gather)
2992 				SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
2993 			if (offset || coffset)
2994 				SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
2995 
2996 			if (grad_x || grad_y)
2997 				texop += "grad";
2998 			else if (lod)
2999 				texop += "lod";
3000 			else if (bias)
3001 				texop += "bias";
3002 			else if (proj || dref)
3003 				texop += "proj";
3004 		}
3005 	}
3006 
3007 	expr += texop;
3008 	expr += "(";
3009 	if (hlsl_options.shader_model < 40)
3010 	{
3011 		if (combined_image)
3012 			SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3013 		expr += to_expression(img);
3014 	}
3015 	else if (op != OpImageFetch)
3016 	{
3017 		string sampler_expr;
3018 		if (combined_image)
3019 			sampler_expr = to_non_uniform_aware_expression(combined_image->sampler);
3020 		else
3021 			sampler_expr = to_sampler_expression(img);
3022 		expr += sampler_expr;
3023 	}
3024 
3025 	auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3026 		if (comps == in_comps)
3027 			return "";
3028 
3029 		switch (comps)
3030 		{
3031 		case 1:
3032 			return ".x";
3033 		case 2:
3034 			return ".xy";
3035 		case 3:
3036 			return ".xyz";
3037 		default:
3038 			return "";
3039 		}
3040 	};
3041 
3042 	bool forward = should_forward(coord);
3043 
3044 	// The IR can give us more components than we need, so chop them off as needed.
3045 	string coord_expr;
3046 	auto &coord_type = expression_type(coord);
3047 	if (coord_components != coord_type.vecsize)
3048 		coord_expr = to_enclosed_expression(coord) + swizzle(coord_components, expression_type(coord).vecsize);
3049 	else
3050 		coord_expr = to_expression(coord);
3051 
3052 	if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3053 		coord_expr = coord_expr + " / " + to_extract_component_expression(coord, coord_components);
3054 
3055 	if (hlsl_options.shader_model < 40)
3056 	{
3057 		if (dref)
3058 		{
3059 			if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3060 			{
3061 				SPIRV_CROSS_THROW(
3062 				    "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3063 			}
3064 
3065 			if (grad_x || grad_y)
3066 				SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3067 
3068 			for (uint32_t size = coord_components; size < 2; ++size)
3069 				coord_expr += ", 0.0";
3070 
3071 			forward = forward && should_forward(dref);
3072 			coord_expr += ", " + to_expression(dref);
3073 		}
3074 		else if (lod || bias || proj)
3075 		{
3076 			for (uint32_t size = coord_components; size < 3; ++size)
3077 				coord_expr += ", 0.0";
3078 		}
3079 
3080 		if (lod)
3081 		{
3082 			coord_expr = "float4(" + coord_expr + ", " + to_expression(lod) + ")";
3083 		}
3084 		else if (bias)
3085 		{
3086 			coord_expr = "float4(" + coord_expr + ", " + to_expression(bias) + ")";
3087 		}
3088 		else if (proj)
3089 		{
3090 			coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(coord, coord_components) + ")";
3091 		}
3092 		else if (dref)
3093 		{
3094 			// A "normal" sample gets fed into tex2Dproj as well, because the
3095 			// regular tex2D accepts only two coordinates.
3096 			coord_expr = "float4(" + coord_expr + ", 1.0)";
3097 		}
3098 
3099 		if (!!lod + !!bias + !!proj > 1)
3100 			SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3101 	}
3102 
3103 	if (op == OpImageFetch)
3104 	{
3105 		if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
3106 			coord_expr =
3107 			    join("int", coord_components + 1, "(", coord_expr, ", ", lod ? to_expression(lod) : string("0"), ")");
3108 	}
3109 	else
3110 		expr += ", ";
3111 	expr += coord_expr;
3112 
3113 	if (dref && hlsl_options.shader_model >= 40)
3114 	{
3115 		forward = forward && should_forward(dref);
3116 		expr += ", ";
3117 
3118 		if (proj)
3119 			expr += to_enclosed_expression(dref) + " / " + to_extract_component_expression(coord, coord_components);
3120 		else
3121 			expr += to_expression(dref);
3122 	}
3123 
3124 	if (!dref && (grad_x || grad_y))
3125 	{
3126 		forward = forward && should_forward(grad_x);
3127 		forward = forward && should_forward(grad_y);
3128 		expr += ", ";
3129 		expr += to_expression(grad_x);
3130 		expr += ", ";
3131 		expr += to_expression(grad_y);
3132 	}
3133 
3134 	if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3135 	{
3136 		forward = forward && should_forward(lod);
3137 		expr += ", ";
3138 		expr += to_expression(lod);
3139 	}
3140 
3141 	if (!dref && bias && hlsl_options.shader_model >= 40)
3142 	{
3143 		forward = forward && should_forward(bias);
3144 		expr += ", ";
3145 		expr += to_expression(bias);
3146 	}
3147 
3148 	if (coffset)
3149 	{
3150 		forward = forward && should_forward(coffset);
3151 		expr += ", ";
3152 		expr += to_expression(coffset);
3153 	}
3154 	else if (offset)
3155 	{
3156 		forward = forward && should_forward(offset);
3157 		expr += ", ";
3158 		expr += to_expression(offset);
3159 	}
3160 
3161 	if (sample)
3162 	{
3163 		expr += ", ";
3164 		expr += to_expression(sample);
3165 	}
3166 
3167 	expr += ")";
3168 
3169 	if (dref && hlsl_options.shader_model < 40)
3170 		expr += ".x";
3171 
3172 	if (op == OpImageQueryLod)
3173 	{
3174 		// This is rather awkward.
3175 		// textureQueryLod returns two values, the "accessed level",
3176 		// as well as the actual LOD lambda.
3177 		// As far as I can tell, there is no way to get the .x component
3178 		// according to GLSL spec, and it depends on the sampler itself.
3179 		// Just assume X == Y, so we will need to splat the result to a float2.
3180 		statement("float _", id, "_tmp = ", expr, ";");
3181 		statement("float2 _", id, " = _", id, "_tmp.xx;");
3182 		set<SPIRExpression>(id, join("_", id), result_type, true);
3183 	}
3184 	else
3185 	{
3186 		emit_op(result_type, id, expr, forward, false);
3187 	}
3188 
3189 	for (auto &inherit : inherited_expressions)
3190 		inherit_expression_dependencies(id, inherit);
3191 
3192 	switch (op)
3193 	{
3194 	case OpImageSampleDrefImplicitLod:
3195 	case OpImageSampleImplicitLod:
3196 	case OpImageSampleProjImplicitLod:
3197 	case OpImageSampleProjDrefImplicitLod:
3198 		register_control_dependent_expression(id);
3199 		break;
3200 
3201 	default:
3202 		break;
3203 	}
3204 }
3205 
to_resource_binding(const SPIRVariable & var)3206 string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3207 {
3208 	const auto &type = get<SPIRType>(var.basetype);
3209 
3210 	// We can remap push constant blocks, even if they don't have any binding decoration.
3211 	if (type.storage != StorageClassPushConstant && !has_decoration(var.self, DecorationBinding))
3212 		return "";
3213 
3214 	char space = '\0';
3215 
3216 	HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3217 
3218 	switch (type.basetype)
3219 	{
3220 	case SPIRType::SampledImage:
3221 		space = 't'; // SRV
3222 		resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3223 		break;
3224 
3225 	case SPIRType::Image:
3226 		if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3227 		{
3228 			if (has_decoration(var.self, DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3229 			{
3230 				space = 't'; // SRV
3231 				resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3232 			}
3233 			else
3234 			{
3235 				space = 'u'; // UAV
3236 				resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3237 			}
3238 		}
3239 		else
3240 		{
3241 			space = 't'; // SRV
3242 			resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3243 		}
3244 		break;
3245 
3246 	case SPIRType::Sampler:
3247 		space = 's';
3248 		resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
3249 		break;
3250 
3251 	case SPIRType::Struct:
3252 	{
3253 		auto storage = type.storage;
3254 		if (storage == StorageClassUniform)
3255 		{
3256 			if (has_decoration(type.self, DecorationBufferBlock))
3257 			{
3258 				Bitset flags = ir.get_buffer_block_flags(var);
3259 				bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3260 				space = is_readonly ? 't' : 'u'; // UAV
3261 				resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3262 			}
3263 			else if (has_decoration(type.self, DecorationBlock))
3264 			{
3265 				space = 'b'; // Constant buffers
3266 				resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
3267 			}
3268 		}
3269 		else if (storage == StorageClassPushConstant)
3270 		{
3271 			space = 'b'; // Constant buffers
3272 			resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
3273 		}
3274 		else if (storage == StorageClassStorageBuffer)
3275 		{
3276 			// UAV or SRV depending on readonly flag.
3277 			Bitset flags = ir.get_buffer_block_flags(var);
3278 			bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3279 			space = is_readonly ? 't' : 'u';
3280 			resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3281 		}
3282 
3283 		break;
3284 	}
3285 	default:
3286 		break;
3287 	}
3288 
3289 	if (!space)
3290 		return "";
3291 
3292 	uint32_t desc_set =
3293 	    resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
3294 	uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
3295 
3296 	if (has_decoration(var.self, DecorationBinding))
3297 		binding = get_decoration(var.self, DecorationBinding);
3298 	if (has_decoration(var.self, DecorationDescriptorSet))
3299 		desc_set = get_decoration(var.self, DecorationDescriptorSet);
3300 
3301 	return to_resource_register(resource_flags, space, binding, desc_set);
3302 }
3303 
to_resource_binding_sampler(const SPIRVariable & var)3304 string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
3305 {
3306 	// For combined image samplers.
3307 	if (!has_decoration(var.self, DecorationBinding))
3308 		return "";
3309 
3310 	return to_resource_register(HLSL_BINDING_AUTO_SAMPLER_BIT, 's', get_decoration(var.self, DecorationBinding),
3311 	                            get_decoration(var.self, DecorationDescriptorSet));
3312 }
3313 
remap_hlsl_resource_binding(HLSLBindingFlagBits type,uint32_t & desc_set,uint32_t & binding)3314 void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
3315 {
3316 	auto itr = resource_bindings.find({ get_execution_model(), desc_set, binding });
3317 	if (itr != end(resource_bindings))
3318 	{
3319 		auto &remap = itr->second;
3320 		remap.second = true;
3321 
3322 		switch (type)
3323 		{
3324 		case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
3325 		case HLSL_BINDING_AUTO_CBV_BIT:
3326 			desc_set = remap.first.cbv.register_space;
3327 			binding = remap.first.cbv.register_binding;
3328 			break;
3329 
3330 		case HLSL_BINDING_AUTO_SRV_BIT:
3331 			desc_set = remap.first.srv.register_space;
3332 			binding = remap.first.srv.register_binding;
3333 			break;
3334 
3335 		case HLSL_BINDING_AUTO_SAMPLER_BIT:
3336 			desc_set = remap.first.sampler.register_space;
3337 			binding = remap.first.sampler.register_binding;
3338 			break;
3339 
3340 		case HLSL_BINDING_AUTO_UAV_BIT:
3341 			desc_set = remap.first.uav.register_space;
3342 			binding = remap.first.uav.register_binding;
3343 			break;
3344 
3345 		default:
3346 			break;
3347 		}
3348 	}
3349 }
3350 
to_resource_register(HLSLBindingFlagBits flag,char space,uint32_t binding,uint32_t space_set)3351 string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
3352 {
3353 	if ((flag & resource_binding_flags) == 0)
3354 	{
3355 		remap_hlsl_resource_binding(flag, space_set, binding);
3356 
3357 		// The push constant block did not have a binding, and there were no remap for it,
3358 		// so, declare without register binding.
3359 		if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
3360 			return "";
3361 
3362 		if (hlsl_options.shader_model >= 51)
3363 			return join(" : register(", space, binding, ", space", space_set, ")");
3364 		else
3365 			return join(" : register(", space, binding, ")");
3366 	}
3367 	else
3368 		return "";
3369 }
3370 
emit_modern_uniform(const SPIRVariable & var)3371 void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
3372 {
3373 	auto &type = get<SPIRType>(var.basetype);
3374 	switch (type.basetype)
3375 	{
3376 	case SPIRType::SampledImage:
3377 	case SPIRType::Image:
3378 	{
3379 		bool is_coherent = false;
3380 		if (type.basetype == SPIRType::Image && type.image.sampled == 2)
3381 			is_coherent = has_decoration(var.self, DecorationCoherent);
3382 
3383 		statement(is_coherent ? "globallycoherent " : "", image_type_hlsl_modern(type, var.self), " ",
3384 		          to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3385 
3386 		if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
3387 		{
3388 			// For combined image samplers, also emit a combined image sampler.
3389 			if (image_is_comparison(type, var.self))
3390 				statement("SamplerComparisonState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3391 				          to_resource_binding_sampler(var), ";");
3392 			else
3393 				statement("SamplerState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3394 				          to_resource_binding_sampler(var), ";");
3395 		}
3396 		break;
3397 	}
3398 
3399 	case SPIRType::Sampler:
3400 		if (comparison_ids.count(var.self))
3401 			statement("SamplerComparisonState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var),
3402 			          ";");
3403 		else
3404 			statement("SamplerState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3405 		break;
3406 
3407 	default:
3408 		statement(variable_decl(var), to_resource_binding(var), ";");
3409 		break;
3410 	}
3411 }
3412 
emit_legacy_uniform(const SPIRVariable & var)3413 void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
3414 {
3415 	auto &type = get<SPIRType>(var.basetype);
3416 	switch (type.basetype)
3417 	{
3418 	case SPIRType::Sampler:
3419 	case SPIRType::Image:
3420 		SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
3421 
3422 	default:
3423 		statement(variable_decl(var), ";");
3424 		break;
3425 	}
3426 }
3427 
emit_uniform(const SPIRVariable & var)3428 void CompilerHLSL::emit_uniform(const SPIRVariable &var)
3429 {
3430 	add_resource_name(var.self);
3431 	if (hlsl_options.shader_model >= 40)
3432 		emit_modern_uniform(var);
3433 	else
3434 		emit_legacy_uniform(var);
3435 }
3436 
emit_complex_bitcast(uint32_t,uint32_t,uint32_t)3437 bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
3438 {
3439 	return false;
3440 }
3441 
bitcast_glsl_op(const SPIRType & out_type,const SPIRType & in_type)3442 string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
3443 {
3444 	if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
3445 		return type_to_glsl(out_type);
3446 	else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
3447 		return type_to_glsl(out_type);
3448 	else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
3449 		return "asuint";
3450 	else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
3451 		return type_to_glsl(out_type);
3452 	else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
3453 		return type_to_glsl(out_type);
3454 	else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
3455 		return "asint";
3456 	else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
3457 		return "asfloat";
3458 	else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
3459 		return "asfloat";
3460 	else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
3461 		SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
3462 	else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
3463 		SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
3464 	else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
3465 		return "asdouble";
3466 	else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
3467 		return "asdouble";
3468 	else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
3469 	{
3470 		if (!requires_explicit_fp16_packing)
3471 		{
3472 			requires_explicit_fp16_packing = true;
3473 			force_recompile();
3474 		}
3475 		return "spvUnpackFloat2x16";
3476 	}
3477 	else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
3478 	{
3479 		if (!requires_explicit_fp16_packing)
3480 		{
3481 			requires_explicit_fp16_packing = true;
3482 			force_recompile();
3483 		}
3484 		return "spvPackFloat2x16";
3485 	}
3486 	else
3487 		return "";
3488 }
3489 
emit_glsl_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)3490 void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
3491 {
3492 	auto op = static_cast<GLSLstd450>(eop);
3493 
3494 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
3495 	uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
3496 	auto int_type = to_signed_basetype(integer_width);
3497 	auto uint_type = to_unsigned_basetype(integer_width);
3498 
3499 	switch (op)
3500 	{
3501 	case GLSLstd450InverseSqrt:
3502 		emit_unary_func_op(result_type, id, args[0], "rsqrt");
3503 		break;
3504 
3505 	case GLSLstd450Fract:
3506 		emit_unary_func_op(result_type, id, args[0], "frac");
3507 		break;
3508 
3509 	case GLSLstd450RoundEven:
3510 		if (hlsl_options.shader_model < 40)
3511 			SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
3512 		emit_unary_func_op(result_type, id, args[0], "round");
3513 		break;
3514 
3515 	case GLSLstd450Acosh:
3516 	case GLSLstd450Asinh:
3517 	case GLSLstd450Atanh:
3518 		SPIRV_CROSS_THROW("Inverse hyperbolics are not supported on HLSL.");
3519 
3520 	case GLSLstd450FMix:
3521 	case GLSLstd450IMix:
3522 		emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "lerp");
3523 		break;
3524 
3525 	case GLSLstd450Atan2:
3526 		emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
3527 		break;
3528 
3529 	case GLSLstd450Fma:
3530 		emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "mad");
3531 		break;
3532 
3533 	case GLSLstd450InterpolateAtCentroid:
3534 		emit_unary_func_op(result_type, id, args[0], "EvaluateAttributeAtCentroid");
3535 		break;
3536 	case GLSLstd450InterpolateAtSample:
3537 		emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeAtSample");
3538 		break;
3539 	case GLSLstd450InterpolateAtOffset:
3540 		emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeSnapped");
3541 		break;
3542 
3543 	case GLSLstd450PackHalf2x16:
3544 		if (!requires_fp16_packing)
3545 		{
3546 			requires_fp16_packing = true;
3547 			force_recompile();
3548 		}
3549 		emit_unary_func_op(result_type, id, args[0], "spvPackHalf2x16");
3550 		break;
3551 
3552 	case GLSLstd450UnpackHalf2x16:
3553 		if (!requires_fp16_packing)
3554 		{
3555 			requires_fp16_packing = true;
3556 			force_recompile();
3557 		}
3558 		emit_unary_func_op(result_type, id, args[0], "spvUnpackHalf2x16");
3559 		break;
3560 
3561 	case GLSLstd450PackSnorm4x8:
3562 		if (!requires_snorm8_packing)
3563 		{
3564 			requires_snorm8_packing = true;
3565 			force_recompile();
3566 		}
3567 		emit_unary_func_op(result_type, id, args[0], "spvPackSnorm4x8");
3568 		break;
3569 
3570 	case GLSLstd450UnpackSnorm4x8:
3571 		if (!requires_snorm8_packing)
3572 		{
3573 			requires_snorm8_packing = true;
3574 			force_recompile();
3575 		}
3576 		emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm4x8");
3577 		break;
3578 
3579 	case GLSLstd450PackUnorm4x8:
3580 		if (!requires_unorm8_packing)
3581 		{
3582 			requires_unorm8_packing = true;
3583 			force_recompile();
3584 		}
3585 		emit_unary_func_op(result_type, id, args[0], "spvPackUnorm4x8");
3586 		break;
3587 
3588 	case GLSLstd450UnpackUnorm4x8:
3589 		if (!requires_unorm8_packing)
3590 		{
3591 			requires_unorm8_packing = true;
3592 			force_recompile();
3593 		}
3594 		emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm4x8");
3595 		break;
3596 
3597 	case GLSLstd450PackSnorm2x16:
3598 		if (!requires_snorm16_packing)
3599 		{
3600 			requires_snorm16_packing = true;
3601 			force_recompile();
3602 		}
3603 		emit_unary_func_op(result_type, id, args[0], "spvPackSnorm2x16");
3604 		break;
3605 
3606 	case GLSLstd450UnpackSnorm2x16:
3607 		if (!requires_snorm16_packing)
3608 		{
3609 			requires_snorm16_packing = true;
3610 			force_recompile();
3611 		}
3612 		emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm2x16");
3613 		break;
3614 
3615 	case GLSLstd450PackUnorm2x16:
3616 		if (!requires_unorm16_packing)
3617 		{
3618 			requires_unorm16_packing = true;
3619 			force_recompile();
3620 		}
3621 		emit_unary_func_op(result_type, id, args[0], "spvPackUnorm2x16");
3622 		break;
3623 
3624 	case GLSLstd450UnpackUnorm2x16:
3625 		if (!requires_unorm16_packing)
3626 		{
3627 			requires_unorm16_packing = true;
3628 			force_recompile();
3629 		}
3630 		emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm2x16");
3631 		break;
3632 
3633 	case GLSLstd450PackDouble2x32:
3634 	case GLSLstd450UnpackDouble2x32:
3635 		SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
3636 
3637 	case GLSLstd450FindILsb:
3638 	{
3639 		auto basetype = expression_type(args[0]).basetype;
3640 		emit_unary_func_op_cast(result_type, id, args[0], "firstbitlow", basetype, basetype);
3641 		break;
3642 	}
3643 
3644 	case GLSLstd450FindSMsb:
3645 		emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", int_type, int_type);
3646 		break;
3647 
3648 	case GLSLstd450FindUMsb:
3649 		emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", uint_type, uint_type);
3650 		break;
3651 
3652 	case GLSLstd450MatrixInverse:
3653 	{
3654 		auto &type = get<SPIRType>(result_type);
3655 		if (type.vecsize == 2 && type.columns == 2)
3656 		{
3657 			if (!requires_inverse_2x2)
3658 			{
3659 				requires_inverse_2x2 = true;
3660 				force_recompile();
3661 			}
3662 		}
3663 		else if (type.vecsize == 3 && type.columns == 3)
3664 		{
3665 			if (!requires_inverse_3x3)
3666 			{
3667 				requires_inverse_3x3 = true;
3668 				force_recompile();
3669 			}
3670 		}
3671 		else if (type.vecsize == 4 && type.columns == 4)
3672 		{
3673 			if (!requires_inverse_4x4)
3674 			{
3675 				requires_inverse_4x4 = true;
3676 				force_recompile();
3677 			}
3678 		}
3679 		emit_unary_func_op(result_type, id, args[0], "spvInverse");
3680 		break;
3681 	}
3682 
3683 	case GLSLstd450Normalize:
3684 		// HLSL does not support scalar versions here.
3685 		if (expression_type(args[0]).vecsize == 1)
3686 		{
3687 			// Returns -1 or 1 for valid input, sign() does the job.
3688 			emit_unary_func_op(result_type, id, args[0], "sign");
3689 		}
3690 		else
3691 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3692 		break;
3693 
3694 	case GLSLstd450Reflect:
3695 		if (get<SPIRType>(result_type).vecsize == 1)
3696 		{
3697 			if (!requires_scalar_reflect)
3698 			{
3699 				requires_scalar_reflect = true;
3700 				force_recompile();
3701 			}
3702 			emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
3703 		}
3704 		else
3705 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3706 		break;
3707 
3708 	case GLSLstd450Refract:
3709 		if (get<SPIRType>(result_type).vecsize == 1)
3710 		{
3711 			if (!requires_scalar_refract)
3712 			{
3713 				requires_scalar_refract = true;
3714 				force_recompile();
3715 			}
3716 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
3717 		}
3718 		else
3719 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3720 		break;
3721 
3722 	case GLSLstd450FaceForward:
3723 		if (get<SPIRType>(result_type).vecsize == 1)
3724 		{
3725 			if (!requires_scalar_faceforward)
3726 			{
3727 				requires_scalar_faceforward = true;
3728 				force_recompile();
3729 			}
3730 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
3731 		}
3732 		else
3733 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3734 		break;
3735 
3736 	default:
3737 		CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3738 		break;
3739 	}
3740 }
3741 
read_access_chain_array(const string & lhs,const SPIRAccessChain & chain)3742 void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
3743 {
3744 	auto &type = get<SPIRType>(chain.basetype);
3745 
3746 	// Need to use a reserved identifier here since it might shadow an identifier in the access chain input or other loops.
3747 	auto ident = get_unique_identifier();
3748 
3749 	statement("[unroll]");
3750 	statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
3751 	          ident, "++)");
3752 	begin_scope();
3753 	auto subchain = chain;
3754 	subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
3755 	subchain.basetype = type.parent_type;
3756 	if (!get<SPIRType>(subchain.basetype).array.empty())
3757 		subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
3758 	read_access_chain(nullptr, join(lhs, "[", ident, "]"), subchain);
3759 	end_scope();
3760 }
3761 
read_access_chain_struct(const string & lhs,const SPIRAccessChain & chain)3762 void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
3763 {
3764 	auto &type = get<SPIRType>(chain.basetype);
3765 	auto subchain = chain;
3766 	uint32_t member_count = uint32_t(type.member_types.size());
3767 
3768 	for (uint32_t i = 0; i < member_count; i++)
3769 	{
3770 		uint32_t offset = type_struct_member_offset(type, i);
3771 		subchain.static_index = chain.static_index + offset;
3772 		subchain.basetype = type.member_types[i];
3773 
3774 		subchain.matrix_stride = 0;
3775 		subchain.array_stride = 0;
3776 		subchain.row_major_matrix = false;
3777 
3778 		auto &member_type = get<SPIRType>(subchain.basetype);
3779 		if (member_type.columns > 1)
3780 		{
3781 			subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
3782 			subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
3783 		}
3784 
3785 		if (!member_type.array.empty())
3786 			subchain.array_stride = type_struct_member_array_stride(type, i);
3787 
3788 		read_access_chain(nullptr, join(lhs, ".", to_member_name(type, i)), subchain);
3789 	}
3790 }
3791 
read_access_chain(string * expr,const string & lhs,const SPIRAccessChain & chain)3792 void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
3793 {
3794 	auto &type = get<SPIRType>(chain.basetype);
3795 
3796 	SPIRType target_type;
3797 	target_type.basetype = SPIRType::UInt;
3798 	target_type.vecsize = type.vecsize;
3799 	target_type.columns = type.columns;
3800 
3801 	if (!type.array.empty())
3802 	{
3803 		read_access_chain_array(lhs, chain);
3804 		return;
3805 	}
3806 	else if (type.basetype == SPIRType::Struct)
3807 	{
3808 		read_access_chain_struct(lhs, chain);
3809 		return;
3810 	}
3811 	else if (type.width != 32 && !hlsl_options.enable_16bit_types)
3812 		SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
3813 		                  "native 16-bit types are enabled.");
3814 
3815 	string base = chain.base;
3816 	if (has_decoration(chain.self, DecorationNonUniform))
3817 		convert_non_uniform_expression(base, chain.self);
3818 
3819 	bool templated_load = hlsl_options.shader_model >= 62;
3820 	string load_expr;
3821 
3822 	string template_expr;
3823 	if (templated_load)
3824 		template_expr = join("<", type_to_glsl(type), ">");
3825 
3826 	// Load a vector or scalar.
3827 	if (type.columns == 1 && !chain.row_major_matrix)
3828 	{
3829 		const char *load_op = nullptr;
3830 		switch (type.vecsize)
3831 		{
3832 		case 1:
3833 			load_op = "Load";
3834 			break;
3835 		case 2:
3836 			load_op = "Load2";
3837 			break;
3838 		case 3:
3839 			load_op = "Load3";
3840 			break;
3841 		case 4:
3842 			load_op = "Load4";
3843 			break;
3844 		default:
3845 			SPIRV_CROSS_THROW("Unknown vector size.");
3846 		}
3847 
3848 		if (templated_load)
3849 			load_op = "Load";
3850 
3851 		load_expr = join(base, ".", load_op, template_expr, "(", chain.dynamic_index, chain.static_index, ")");
3852 	}
3853 	else if (type.columns == 1)
3854 	{
3855 		// Strided load since we are loading a column from a row-major matrix.
3856 		if (templated_load)
3857 		{
3858 			auto scalar_type = type;
3859 			scalar_type.vecsize = 1;
3860 			scalar_type.columns = 1;
3861 			template_expr = join("<", type_to_glsl(scalar_type), ">");
3862 			if (type.vecsize > 1)
3863 				load_expr += type_to_glsl(type) + "(";
3864 		}
3865 		else if (type.vecsize > 1)
3866 		{
3867 			load_expr = type_to_glsl(target_type);
3868 			load_expr += "(";
3869 		}
3870 
3871 		for (uint32_t r = 0; r < type.vecsize; r++)
3872 		{
3873 			load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3874 			                  chain.static_index + r * chain.matrix_stride, ")");
3875 			if (r + 1 < type.vecsize)
3876 				load_expr += ", ";
3877 		}
3878 
3879 		if (type.vecsize > 1)
3880 			load_expr += ")";
3881 	}
3882 	else if (!chain.row_major_matrix)
3883 	{
3884 		// Load a matrix, column-major, the easy case.
3885 		const char *load_op = nullptr;
3886 		switch (type.vecsize)
3887 		{
3888 		case 1:
3889 			load_op = "Load";
3890 			break;
3891 		case 2:
3892 			load_op = "Load2";
3893 			break;
3894 		case 3:
3895 			load_op = "Load3";
3896 			break;
3897 		case 4:
3898 			load_op = "Load4";
3899 			break;
3900 		default:
3901 			SPIRV_CROSS_THROW("Unknown vector size.");
3902 		}
3903 
3904 		if (templated_load)
3905 		{
3906 			auto vector_type = type;
3907 			vector_type.columns = 1;
3908 			template_expr = join("<", type_to_glsl(vector_type), ">");
3909 			load_expr = type_to_glsl(type);
3910 			load_op = "Load";
3911 		}
3912 		else
3913 		{
3914 			// Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
3915 			// so row-major is technically column-major ...
3916 			load_expr = type_to_glsl(target_type);
3917 		}
3918 		load_expr += "(";
3919 
3920 		for (uint32_t c = 0; c < type.columns; c++)
3921 		{
3922 			load_expr += join(base, ".", load_op, template_expr, "(", chain.dynamic_index,
3923 			                  chain.static_index + c * chain.matrix_stride, ")");
3924 			if (c + 1 < type.columns)
3925 				load_expr += ", ";
3926 		}
3927 		load_expr += ")";
3928 	}
3929 	else
3930 	{
3931 		// Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
3932 		// considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
3933 
3934 		if (templated_load)
3935 		{
3936 			load_expr = type_to_glsl(type);
3937 			auto scalar_type = type;
3938 			scalar_type.vecsize = 1;
3939 			scalar_type.columns = 1;
3940 			template_expr = join("<", type_to_glsl(scalar_type), ">");
3941 		}
3942 		else
3943 			load_expr = type_to_glsl(target_type);
3944 
3945 		load_expr += "(";
3946 
3947 		for (uint32_t c = 0; c < type.columns; c++)
3948 		{
3949 			for (uint32_t r = 0; r < type.vecsize; r++)
3950 			{
3951 				load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3952 				                  chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")");
3953 
3954 				if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
3955 					load_expr += ", ";
3956 			}
3957 		}
3958 		load_expr += ")";
3959 	}
3960 
3961 	if (!templated_load)
3962 	{
3963 		auto bitcast_op = bitcast_glsl_op(type, target_type);
3964 		if (!bitcast_op.empty())
3965 			load_expr = join(bitcast_op, "(", load_expr, ")");
3966 	}
3967 
3968 	if (lhs.empty())
3969 	{
3970 		assert(expr);
3971 		*expr = move(load_expr);
3972 	}
3973 	else
3974 		statement(lhs, " = ", load_expr, ";");
3975 }
3976 
emit_load(const Instruction & instruction)3977 void CompilerHLSL::emit_load(const Instruction &instruction)
3978 {
3979 	auto ops = stream(instruction);
3980 
3981 	auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
3982 	if (chain)
3983 	{
3984 		uint32_t result_type = ops[0];
3985 		uint32_t id = ops[1];
3986 		uint32_t ptr = ops[2];
3987 
3988 		auto &type = get<SPIRType>(result_type);
3989 		bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
3990 
3991 		if (composite_load)
3992 		{
3993 			// We cannot make this work in one single expression as we might have nested structures and arrays,
3994 			// so unroll the load to an uninitialized temporary.
3995 			emit_uninitialized_temporary_expression(result_type, id);
3996 			read_access_chain(nullptr, to_expression(id), *chain);
3997 			track_expression_read(chain->self);
3998 		}
3999 		else
4000 		{
4001 			string load_expr;
4002 			read_access_chain(&load_expr, "", *chain);
4003 
4004 			bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries);
4005 
4006 			// If we are forwarding this load,
4007 			// don't register the read to access chain here, defer that to when we actually use the expression,
4008 			// using the add_implied_read_expression mechanism.
4009 			if (!forward)
4010 				track_expression_read(chain->self);
4011 
4012 			// Do not forward complex load sequences like matrices, structs and arrays.
4013 			if (type.columns > 1)
4014 				forward = false;
4015 
4016 			auto &e = emit_op(result_type, id, load_expr, forward, true);
4017 			e.need_transpose = false;
4018 			register_read(id, ptr, forward);
4019 			inherit_expression_dependencies(id, ptr);
4020 			if (forward)
4021 				add_implied_read_expression(e, chain->self);
4022 		}
4023 	}
4024 	else
4025 		CompilerGLSL::emit_instruction(instruction);
4026 }
4027 
write_access_chain_array(const SPIRAccessChain & chain,uint32_t value,const SmallVector<uint32_t> & composite_chain)4028 void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4029                                             const SmallVector<uint32_t> &composite_chain)
4030 {
4031 	auto &type = get<SPIRType>(chain.basetype);
4032 
4033 	// Need to use a reserved identifier here since it might shadow an identifier in the access chain input or other loops.
4034 	auto ident = get_unique_identifier();
4035 
4036 	uint32_t id = ir.increase_bound_by(2);
4037 	uint32_t int_type_id = id + 1;
4038 	SPIRType int_type;
4039 	int_type.basetype = SPIRType::Int;
4040 	int_type.width = 32;
4041 	set<SPIRType>(int_type_id, int_type);
4042 	set<SPIRExpression>(id, ident, int_type_id, true);
4043 	set_name(id, ident);
4044 	suppressed_usage_tracking.insert(id);
4045 
4046 	statement("[unroll]");
4047 	statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
4048 	          ident, "++)");
4049 	begin_scope();
4050 	auto subchain = chain;
4051 	subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
4052 	subchain.basetype = type.parent_type;
4053 
4054 	// Forcefully allow us to use an ID here by setting MSB.
4055 	auto subcomposite_chain = composite_chain;
4056 	subcomposite_chain.push_back(0x80000000u | id);
4057 
4058 	if (!get<SPIRType>(subchain.basetype).array.empty())
4059 		subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
4060 
4061 	write_access_chain(subchain, value, subcomposite_chain);
4062 	end_scope();
4063 }
4064 
write_access_chain_struct(const SPIRAccessChain & chain,uint32_t value,const SmallVector<uint32_t> & composite_chain)4065 void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4066                                              const SmallVector<uint32_t> &composite_chain)
4067 {
4068 	auto &type = get<SPIRType>(chain.basetype);
4069 	uint32_t member_count = uint32_t(type.member_types.size());
4070 	auto subchain = chain;
4071 
4072 	auto subcomposite_chain = composite_chain;
4073 	subcomposite_chain.push_back(0);
4074 
4075 	for (uint32_t i = 0; i < member_count; i++)
4076 	{
4077 		uint32_t offset = type_struct_member_offset(type, i);
4078 		subchain.static_index = chain.static_index + offset;
4079 		subchain.basetype = type.member_types[i];
4080 
4081 		subchain.matrix_stride = 0;
4082 		subchain.array_stride = 0;
4083 		subchain.row_major_matrix = false;
4084 
4085 		auto &member_type = get<SPIRType>(subchain.basetype);
4086 		if (member_type.columns > 1)
4087 		{
4088 			subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
4089 			subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
4090 		}
4091 
4092 		if (!member_type.array.empty())
4093 			subchain.array_stride = type_struct_member_array_stride(type, i);
4094 
4095 		subcomposite_chain.back() = i;
4096 		write_access_chain(subchain, value, subcomposite_chain);
4097 	}
4098 }
4099 
write_access_chain_value(uint32_t value,const SmallVector<uint32_t> & composite_chain,bool enclose)4100 string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4101                                               bool enclose)
4102 {
4103 	string ret;
4104 	if (composite_chain.empty())
4105 		ret = to_expression(value);
4106 	else
4107 	{
4108 		AccessChainMeta meta;
4109 		ret = access_chain_internal(value, composite_chain.data(), uint32_t(composite_chain.size()),
4110 		                            ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, &meta);
4111 	}
4112 
4113 	if (enclose)
4114 		ret = enclose_expression(ret);
4115 	return ret;
4116 }
4117 
write_access_chain(const SPIRAccessChain & chain,uint32_t value,const SmallVector<uint32_t> & composite_chain)4118 void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4119                                       const SmallVector<uint32_t> &composite_chain)
4120 {
4121 	auto &type = get<SPIRType>(chain.basetype);
4122 
4123 	// Make sure we trigger a read of the constituents in the access chain.
4124 	track_expression_read(chain.self);
4125 
4126 	SPIRType target_type;
4127 	target_type.basetype = SPIRType::UInt;
4128 	target_type.vecsize = type.vecsize;
4129 	target_type.columns = type.columns;
4130 
4131 	if (!type.array.empty())
4132 	{
4133 		write_access_chain_array(chain, value, composite_chain);
4134 		register_write(chain.self);
4135 		return;
4136 	}
4137 	else if (type.basetype == SPIRType::Struct)
4138 	{
4139 		write_access_chain_struct(chain, value, composite_chain);
4140 		register_write(chain.self);
4141 		return;
4142 	}
4143 	else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4144 		SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4145 		                  "native 16-bit types are enabled.");
4146 
4147 	bool templated_store = hlsl_options.shader_model >= 62;
4148 
4149 	auto base = chain.base;
4150 	if (has_decoration(chain.self, DecorationNonUniform))
4151 		convert_non_uniform_expression(base, chain.self);
4152 
4153 	string template_expr;
4154 	if (templated_store)
4155 		template_expr = join("<", type_to_glsl(type), ">");
4156 
4157 	if (type.columns == 1 && !chain.row_major_matrix)
4158 	{
4159 		const char *store_op = nullptr;
4160 		switch (type.vecsize)
4161 		{
4162 		case 1:
4163 			store_op = "Store";
4164 			break;
4165 		case 2:
4166 			store_op = "Store2";
4167 			break;
4168 		case 3:
4169 			store_op = "Store3";
4170 			break;
4171 		case 4:
4172 			store_op = "Store4";
4173 			break;
4174 		default:
4175 			SPIRV_CROSS_THROW("Unknown vector size.");
4176 		}
4177 
4178 		auto store_expr = write_access_chain_value(value, composite_chain, false);
4179 
4180 		if (!templated_store)
4181 		{
4182 			auto bitcast_op = bitcast_glsl_op(target_type, type);
4183 			if (!bitcast_op.empty())
4184 				store_expr = join(bitcast_op, "(", store_expr, ")");
4185 		}
4186 		else
4187 			store_op = "Store";
4188 		statement(base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index, ", ",
4189 		          store_expr, ");");
4190 	}
4191 	else if (type.columns == 1)
4192 	{
4193 		if (templated_store)
4194 		{
4195 			auto scalar_type = type;
4196 			scalar_type.vecsize = 1;
4197 			scalar_type.columns = 1;
4198 			template_expr = join("<", type_to_glsl(scalar_type), ">");
4199 		}
4200 
4201 		// Strided store.
4202 		for (uint32_t r = 0; r < type.vecsize; r++)
4203 		{
4204 			auto store_expr = write_access_chain_value(value, composite_chain, true);
4205 			if (type.vecsize > 1)
4206 			{
4207 				store_expr += ".";
4208 				store_expr += index_to_swizzle(r);
4209 			}
4210 			remove_duplicate_swizzle(store_expr);
4211 
4212 			if (!templated_store)
4213 			{
4214 				auto bitcast_op = bitcast_glsl_op(target_type, type);
4215 				if (!bitcast_op.empty())
4216 					store_expr = join(bitcast_op, "(", store_expr, ")");
4217 			}
4218 
4219 			statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4220 			          chain.static_index + chain.matrix_stride * r, ", ", store_expr, ");");
4221 		}
4222 	}
4223 	else if (!chain.row_major_matrix)
4224 	{
4225 		const char *store_op = nullptr;
4226 		switch (type.vecsize)
4227 		{
4228 		case 1:
4229 			store_op = "Store";
4230 			break;
4231 		case 2:
4232 			store_op = "Store2";
4233 			break;
4234 		case 3:
4235 			store_op = "Store3";
4236 			break;
4237 		case 4:
4238 			store_op = "Store4";
4239 			break;
4240 		default:
4241 			SPIRV_CROSS_THROW("Unknown vector size.");
4242 		}
4243 
4244 		if (templated_store)
4245 		{
4246 			store_op = "Store";
4247 			auto vector_type = type;
4248 			vector_type.columns = 1;
4249 			template_expr = join("<", type_to_glsl(vector_type), ">");
4250 		}
4251 
4252 		for (uint32_t c = 0; c < type.columns; c++)
4253 		{
4254 			auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]");
4255 
4256 			if (!templated_store)
4257 			{
4258 				auto bitcast_op = bitcast_glsl_op(target_type, type);
4259 				if (!bitcast_op.empty())
4260 					store_expr = join(bitcast_op, "(", store_expr, ")");
4261 			}
4262 
4263 			statement(base, ".", store_op, template_expr, "(", chain.dynamic_index,
4264 			          chain.static_index + c * chain.matrix_stride, ", ", store_expr, ");");
4265 		}
4266 	}
4267 	else
4268 	{
4269 		if (templated_store)
4270 		{
4271 			auto scalar_type = type;
4272 			scalar_type.vecsize = 1;
4273 			scalar_type.columns = 1;
4274 			template_expr = join("<", type_to_glsl(scalar_type), ">");
4275 		}
4276 
4277 		for (uint32_t r = 0; r < type.vecsize; r++)
4278 		{
4279 			for (uint32_t c = 0; c < type.columns; c++)
4280 			{
4281 				auto store_expr =
4282 				    join(write_access_chain_value(value, composite_chain, true), "[", c, "].", index_to_swizzle(r));
4283 				remove_duplicate_swizzle(store_expr);
4284 				auto bitcast_op = bitcast_glsl_op(target_type, type);
4285 				if (!bitcast_op.empty())
4286 					store_expr = join(bitcast_op, "(", store_expr, ")");
4287 				statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4288 				          chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");");
4289 			}
4290 		}
4291 	}
4292 
4293 	register_write(chain.self);
4294 }
4295 
emit_store(const Instruction & instruction)4296 void CompilerHLSL::emit_store(const Instruction &instruction)
4297 {
4298 	auto ops = stream(instruction);
4299 	auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4300 	if (chain)
4301 		write_access_chain(*chain, ops[1], {});
4302 	else
4303 		CompilerGLSL::emit_instruction(instruction);
4304 }
4305 
emit_access_chain(const Instruction & instruction)4306 void CompilerHLSL::emit_access_chain(const Instruction &instruction)
4307 {
4308 	auto ops = stream(instruction);
4309 	uint32_t length = instruction.length;
4310 
4311 	bool need_byte_access_chain = false;
4312 	auto &type = expression_type(ops[2]);
4313 	const auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4314 
4315 	if (chain)
4316 	{
4317 		// Keep tacking on an existing access chain.
4318 		need_byte_access_chain = true;
4319 	}
4320 	else if (type.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock))
4321 	{
4322 		// If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
4323 		// to emit SPIRAccessChain rather than a plain SPIRExpression.
4324 		uint32_t chain_arguments = length - 3;
4325 		if (chain_arguments > type.array.size())
4326 			need_byte_access_chain = true;
4327 	}
4328 
4329 	if (need_byte_access_chain)
4330 	{
4331 		// If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within a block,
4332 		// and not array of SSBO.
4333 		uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
4334 
4335 		auto *backing_variable = maybe_get_backing_variable(ops[2]);
4336 
4337 		string base;
4338 		if (to_plain_buffer_length != 0)
4339 			base = access_chain(ops[2], &ops[3], to_plain_buffer_length, get<SPIRType>(ops[0]));
4340 		else if (chain)
4341 			base = chain->base;
4342 		else
4343 			base = to_expression(ops[2]);
4344 
4345 		// Start traversing type hierarchy at the proper non-pointer types.
4346 		auto *basetype = &get_pointee_type(type);
4347 
4348 		// Traverse the type hierarchy down to the actual buffer types.
4349 		for (uint32_t i = 0; i < to_plain_buffer_length; i++)
4350 		{
4351 			assert(basetype->parent_type);
4352 			basetype = &get<SPIRType>(basetype->parent_type);
4353 		}
4354 
4355 		uint32_t matrix_stride = 0;
4356 		uint32_t array_stride = 0;
4357 		bool row_major_matrix = false;
4358 
4359 		// Inherit matrix information.
4360 		if (chain)
4361 		{
4362 			matrix_stride = chain->matrix_stride;
4363 			row_major_matrix = chain->row_major_matrix;
4364 			array_stride = chain->array_stride;
4365 		}
4366 
4367 		auto offsets = flattened_access_chain_offset(*basetype, &ops[3 + to_plain_buffer_length],
4368 		                                             length - 3 - to_plain_buffer_length, 0, 1, &row_major_matrix,
4369 		                                             &matrix_stride, &array_stride);
4370 
4371 		auto &e = set<SPIRAccessChain>(ops[1], ops[0], type.storage, base, offsets.first, offsets.second);
4372 		e.row_major_matrix = row_major_matrix;
4373 		e.matrix_stride = matrix_stride;
4374 		e.array_stride = array_stride;
4375 		e.immutable = should_forward(ops[2]);
4376 		e.loaded_from = backing_variable ? backing_variable->self : ID(0);
4377 
4378 		if (chain)
4379 		{
4380 			e.dynamic_index += chain->dynamic_index;
4381 			e.static_index += chain->static_index;
4382 		}
4383 
4384 		for (uint32_t i = 2; i < length; i++)
4385 		{
4386 			inherit_expression_dependencies(ops[1], ops[i]);
4387 			add_implied_read_expression(e, ops[i]);
4388 		}
4389 	}
4390 	else
4391 	{
4392 		CompilerGLSL::emit_instruction(instruction);
4393 	}
4394 }
4395 
emit_atomic(const uint32_t * ops,uint32_t length,spv::Op op)4396 void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
4397 {
4398 	const char *atomic_op = nullptr;
4399 
4400 	string value_expr;
4401 	if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
4402 		value_expr = to_expression(ops[op == OpAtomicCompareExchange ? 6 : 5]);
4403 
4404 	bool is_atomic_store = false;
4405 
4406 	switch (op)
4407 	{
4408 	case OpAtomicIIncrement:
4409 		atomic_op = "InterlockedAdd";
4410 		value_expr = "1";
4411 		break;
4412 
4413 	case OpAtomicIDecrement:
4414 		atomic_op = "InterlockedAdd";
4415 		value_expr = "-1";
4416 		break;
4417 
4418 	case OpAtomicLoad:
4419 		atomic_op = "InterlockedAdd";
4420 		value_expr = "0";
4421 		break;
4422 
4423 	case OpAtomicISub:
4424 		atomic_op = "InterlockedAdd";
4425 		value_expr = join("-", enclose_expression(value_expr));
4426 		break;
4427 
4428 	case OpAtomicSMin:
4429 	case OpAtomicUMin:
4430 		atomic_op = "InterlockedMin";
4431 		break;
4432 
4433 	case OpAtomicSMax:
4434 	case OpAtomicUMax:
4435 		atomic_op = "InterlockedMax";
4436 		break;
4437 
4438 	case OpAtomicAnd:
4439 		atomic_op = "InterlockedAnd";
4440 		break;
4441 
4442 	case OpAtomicOr:
4443 		atomic_op = "InterlockedOr";
4444 		break;
4445 
4446 	case OpAtomicXor:
4447 		atomic_op = "InterlockedXor";
4448 		break;
4449 
4450 	case OpAtomicIAdd:
4451 		atomic_op = "InterlockedAdd";
4452 		break;
4453 
4454 	case OpAtomicExchange:
4455 		atomic_op = "InterlockedExchange";
4456 		break;
4457 
4458 	case OpAtomicStore:
4459 		atomic_op = "InterlockedExchange";
4460 		is_atomic_store = true;
4461 		break;
4462 
4463 	case OpAtomicCompareExchange:
4464 		if (length < 8)
4465 			SPIRV_CROSS_THROW("Not enough data for opcode.");
4466 		atomic_op = "InterlockedCompareExchange";
4467 		value_expr = join(to_expression(ops[7]), ", ", value_expr);
4468 		break;
4469 
4470 	default:
4471 		SPIRV_CROSS_THROW("Unknown atomic opcode.");
4472 	}
4473 
4474 	if (is_atomic_store)
4475 	{
4476 		auto &data_type = expression_type(ops[0]);
4477 		auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4478 
4479 		auto &tmp_id = extra_sub_expressions[ops[0]];
4480 		if (!tmp_id)
4481 		{
4482 			tmp_id = ir.increase_bound_by(1);
4483 			emit_uninitialized_temporary_expression(get_pointee_type(data_type).self, tmp_id);
4484 		}
4485 
4486 		if (data_type.storage == StorageClassImage || !chain)
4487 		{
4488 			statement(atomic_op, "(", to_non_uniform_aware_expression(ops[0]), ", ",
4489 			          to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4490 		}
4491 		else
4492 		{
4493 			string base = chain->base;
4494 			if (has_decoration(chain->self, DecorationNonUniform))
4495 				convert_non_uniform_expression(base, chain->self);
4496 			// RWByteAddress buffer is always uint in its underlying type.
4497 			statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ",
4498 			          to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4499 		}
4500 	}
4501 	else
4502 	{
4503 		uint32_t result_type = ops[0];
4504 		uint32_t id = ops[1];
4505 		forced_temporaries.insert(ops[1]);
4506 
4507 		auto &type = get<SPIRType>(result_type);
4508 		statement(variable_decl(type, to_name(id)), ";");
4509 
4510 		auto &data_type = expression_type(ops[2]);
4511 		auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4512 		SPIRType::BaseType expr_type;
4513 		if (data_type.storage == StorageClassImage || !chain)
4514 		{
4515 			statement(atomic_op, "(", to_non_uniform_aware_expression(ops[2]), ", ", value_expr, ", ", to_name(id), ");");
4516 			expr_type = data_type.basetype;
4517 		}
4518 		else
4519 		{
4520 			// RWByteAddress buffer is always uint in its underlying type.
4521 			string base = chain->base;
4522 			if (has_decoration(chain->self, DecorationNonUniform))
4523 				convert_non_uniform_expression(base, chain->self);
4524 			expr_type = SPIRType::UInt;
4525 			statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", value_expr,
4526 			          ", ", to_name(id), ");");
4527 		}
4528 
4529 		auto expr = bitcast_expression(type, expr_type, to_name(id));
4530 		set<SPIRExpression>(id, expr, result_type, true);
4531 	}
4532 	flush_all_atomic_capable_variables();
4533 }
4534 
emit_subgroup_op(const Instruction & i)4535 void CompilerHLSL::emit_subgroup_op(const Instruction &i)
4536 {
4537 	if (hlsl_options.shader_model < 60)
4538 		SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
4539 
4540 	const uint32_t *ops = stream(i);
4541 	auto op = static_cast<Op>(i.op);
4542 
4543 	uint32_t result_type = ops[0];
4544 	uint32_t id = ops[1];
4545 
4546 	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
4547 	if (scope != ScopeSubgroup)
4548 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
4549 
4550 	const auto make_inclusive_Sum = [&](const string &expr) -> string {
4551 		return join(expr, " + ", to_expression(ops[4]));
4552 	};
4553 
4554 	const auto make_inclusive_Product = [&](const string &expr) -> string {
4555 		return join(expr, " * ", to_expression(ops[4]));
4556 	};
4557 
4558 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
4559 	uint32_t integer_width = get_integer_width_for_instruction(i);
4560 	auto int_type = to_signed_basetype(integer_width);
4561 	auto uint_type = to_unsigned_basetype(integer_width);
4562 
4563 #define make_inclusive_BitAnd(expr) ""
4564 #define make_inclusive_BitOr(expr) ""
4565 #define make_inclusive_BitXor(expr) ""
4566 #define make_inclusive_Min(expr) ""
4567 #define make_inclusive_Max(expr) ""
4568 
4569 	switch (op)
4570 	{
4571 	case OpGroupNonUniformElect:
4572 		emit_op(result_type, id, "WaveIsFirstLane()", true);
4573 		break;
4574 
4575 	case OpGroupNonUniformBroadcast:
4576 		emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4577 		break;
4578 
4579 	case OpGroupNonUniformBroadcastFirst:
4580 		emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
4581 		break;
4582 
4583 	case OpGroupNonUniformBallot:
4584 		emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
4585 		break;
4586 
4587 	case OpGroupNonUniformInverseBallot:
4588 		SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
4589 
4590 	case OpGroupNonUniformBallotBitExtract:
4591 		SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
4592 
4593 	case OpGroupNonUniformBallotFindLSB:
4594 		SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
4595 
4596 	case OpGroupNonUniformBallotFindMSB:
4597 		SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
4598 
4599 	case OpGroupNonUniformBallotBitCount:
4600 	{
4601 		auto operation = static_cast<GroupOperation>(ops[3]);
4602 		if (operation == GroupOperationReduce)
4603 		{
4604 			bool forward = should_forward(ops[4]);
4605 			auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
4606 			                 to_enclosed_expression(ops[4]), ".y)");
4607 			auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
4608 			                  to_enclosed_expression(ops[4]), ".w)");
4609 			emit_op(result_type, id, join(left, " + ", right), forward);
4610 			inherit_expression_dependencies(id, ops[4]);
4611 		}
4612 		else if (operation == GroupOperationInclusiveScan)
4613 			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
4614 		else if (operation == GroupOperationExclusiveScan)
4615 			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
4616 		else
4617 			SPIRV_CROSS_THROW("Invalid BitCount operation.");
4618 		break;
4619 	}
4620 
4621 	case OpGroupNonUniformShuffle:
4622 		emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4623 		break;
4624 	case OpGroupNonUniformShuffleXor:
4625 	{
4626 		bool forward = should_forward(ops[3]);
4627 		emit_op(ops[0], ops[1],
4628 		        join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4629 		             "WaveGetLaneIndex() ^ ", to_enclosed_expression(ops[4]), ")"), forward);
4630 		inherit_expression_dependencies(ops[1], ops[3]);
4631 		break;
4632 	}
4633 	case OpGroupNonUniformShuffleUp:
4634 	{
4635 		bool forward = should_forward(ops[3]);
4636 		emit_op(ops[0], ops[1],
4637 		        join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4638 		             "WaveGetLaneIndex() - ", to_enclosed_expression(ops[4]), ")"), forward);
4639 		inherit_expression_dependencies(ops[1], ops[3]);
4640 		break;
4641 	}
4642 	case OpGroupNonUniformShuffleDown:
4643 	{
4644 		bool forward = should_forward(ops[3]);
4645 		emit_op(ops[0], ops[1],
4646 		        join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4647 		             "WaveGetLaneIndex() + ", to_enclosed_expression(ops[4]), ")"), forward);
4648 		inherit_expression_dependencies(ops[1], ops[3]);
4649 		break;
4650 	}
4651 
4652 	case OpGroupNonUniformAll:
4653 		emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
4654 		break;
4655 
4656 	case OpGroupNonUniformAny:
4657 		emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
4658 		break;
4659 
4660 	case OpGroupNonUniformAllEqual:
4661 		emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllEqual");
4662 		break;
4663 
4664 	// clang-format off
4665 #define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
4666 case OpGroupNonUniform##op: \
4667 	{ \
4668 		auto operation = static_cast<GroupOperation>(ops[3]); \
4669 		if (operation == GroupOperationReduce) \
4670 			emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
4671 		else if (operation == GroupOperationInclusiveScan && supports_scan) \
4672         { \
4673 			bool forward = should_forward(ops[4]); \
4674 			emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
4675 			inherit_expression_dependencies(id, ops[4]); \
4676         } \
4677 		else if (operation == GroupOperationExclusiveScan && supports_scan) \
4678 			emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
4679 		else if (operation == GroupOperationClusteredReduce) \
4680 			SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
4681 		else \
4682 			SPIRV_CROSS_THROW("Invalid group operation."); \
4683 		break; \
4684 	}
4685 
4686 #define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
4687 case OpGroupNonUniform##op: \
4688 	{ \
4689 		auto operation = static_cast<GroupOperation>(ops[3]); \
4690 		if (operation == GroupOperationReduce) \
4691 			emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
4692 		else \
4693 			SPIRV_CROSS_THROW("Invalid group operation."); \
4694 		break; \
4695 	}
4696 
4697 	HLSL_GROUP_OP(FAdd, Sum, true)
4698 	HLSL_GROUP_OP(FMul, Product, true)
4699 	HLSL_GROUP_OP(FMin, Min, false)
4700 	HLSL_GROUP_OP(FMax, Max, false)
4701 	HLSL_GROUP_OP(IAdd, Sum, true)
4702 	HLSL_GROUP_OP(IMul, Product, true)
4703 	HLSL_GROUP_OP_CAST(SMin, Min, int_type)
4704 	HLSL_GROUP_OP_CAST(SMax, Max, int_type)
4705 	HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
4706 	HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
4707 	HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
4708 	HLSL_GROUP_OP(BitwiseOr, BitOr, false)
4709 	HLSL_GROUP_OP(BitwiseXor, BitXor, false)
4710 	HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
4711 	HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
4712 	HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
4713 
4714 #undef HLSL_GROUP_OP
4715 #undef HLSL_GROUP_OP_CAST
4716 		// clang-format on
4717 
4718 	case OpGroupNonUniformQuadSwap:
4719 	{
4720 		uint32_t direction = evaluate_constant_u32(ops[4]);
4721 		if (direction == 0)
4722 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
4723 		else if (direction == 1)
4724 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
4725 		else if (direction == 2)
4726 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
4727 		else
4728 			SPIRV_CROSS_THROW("Invalid quad swap direction.");
4729 		break;
4730 	}
4731 
4732 	case OpGroupNonUniformQuadBroadcast:
4733 	{
4734 		emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
4735 		break;
4736 	}
4737 
4738 	default:
4739 		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
4740 	}
4741 
4742 	register_control_dependent_expression(id);
4743 }
4744 
emit_instruction(const Instruction & instruction)4745 void CompilerHLSL::emit_instruction(const Instruction &instruction)
4746 {
4747 	auto ops = stream(instruction);
4748 	auto opcode = static_cast<Op>(instruction.op);
4749 
4750 #define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
4751 #define HLSL_BOP_CAST(op, type) \
4752 	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4753 #define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
4754 #define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
4755 #define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
4756 #define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4757 #define HLSL_BFOP_CAST(op, type) \
4758 	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4759 #define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4760 #define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
4761 
4762 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
4763 	uint32_t integer_width = get_integer_width_for_instruction(instruction);
4764 	auto int_type = to_signed_basetype(integer_width);
4765 	auto uint_type = to_unsigned_basetype(integer_width);
4766 
4767 	switch (opcode)
4768 	{
4769 	case OpAccessChain:
4770 	case OpInBoundsAccessChain:
4771 	{
4772 		emit_access_chain(instruction);
4773 		break;
4774 	}
4775 	case OpBitcast:
4776 	{
4777 		auto bitcast_type = get_bitcast_type(ops[0], ops[2]);
4778 		if (bitcast_type == CompilerHLSL::TypeNormal)
4779 			CompilerGLSL::emit_instruction(instruction);
4780 		else
4781 		{
4782 			if (!requires_uint2_packing)
4783 			{
4784 				requires_uint2_packing = true;
4785 				force_recompile();
4786 			}
4787 
4788 			if (bitcast_type == CompilerHLSL::TypePackUint2x32)
4789 				emit_unary_func_op(ops[0], ops[1], ops[2], "spvPackUint2x32");
4790 			else
4791 				emit_unary_func_op(ops[0], ops[1], ops[2], "spvUnpackUint2x32");
4792 		}
4793 
4794 		break;
4795 	}
4796 
4797 	case OpSelect:
4798 	{
4799 		auto &value_type = expression_type(ops[3]);
4800 		if (value_type.basetype == SPIRType::Struct || is_array(value_type))
4801 		{
4802 			// HLSL does not support ternary expressions on composites.
4803 			// Cannot use branches, since we might be in a continue block
4804 			// where explicit control flow is prohibited.
4805 			// Emit a helper function where we can use control flow.
4806 			TypeID value_type_id = expression_type_id(ops[3]);
4807 			auto itr = std::find(composite_selection_workaround_types.begin(),
4808 			                     composite_selection_workaround_types.end(),
4809 			                     value_type_id);
4810 			if (itr == composite_selection_workaround_types.end())
4811 			{
4812 				composite_selection_workaround_types.push_back(value_type_id);
4813 				force_recompile();
4814 			}
4815 			emit_uninitialized_temporary_expression(ops[0], ops[1]);
4816 			statement("spvSelectComposite(",
4817 					  to_expression(ops[1]), ", ", to_expression(ops[2]), ", ",
4818 					  to_expression(ops[3]), ", ", to_expression(ops[4]), ");");
4819 		}
4820 		else
4821 			CompilerGLSL::emit_instruction(instruction);
4822 		break;
4823 	}
4824 
4825 	case OpStore:
4826 	{
4827 		emit_store(instruction);
4828 		break;
4829 	}
4830 
4831 	case OpLoad:
4832 	{
4833 		emit_load(instruction);
4834 		break;
4835 	}
4836 
4837 	case OpMatrixTimesVector:
4838 	{
4839 		// Matrices are kept in a transposed state all the time, flip multiplication order always.
4840 		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4841 		break;
4842 	}
4843 
4844 	case OpVectorTimesMatrix:
4845 	{
4846 		// Matrices are kept in a transposed state all the time, flip multiplication order always.
4847 		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4848 		break;
4849 	}
4850 
4851 	case OpMatrixTimesMatrix:
4852 	{
4853 		// Matrices are kept in a transposed state all the time, flip multiplication order always.
4854 		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4855 		break;
4856 	}
4857 
4858 	case OpOuterProduct:
4859 	{
4860 		uint32_t result_type = ops[0];
4861 		uint32_t id = ops[1];
4862 		uint32_t a = ops[2];
4863 		uint32_t b = ops[3];
4864 
4865 		auto &type = get<SPIRType>(result_type);
4866 		string expr = type_to_glsl_constructor(type);
4867 		expr += "(";
4868 		for (uint32_t col = 0; col < type.columns; col++)
4869 		{
4870 			expr += to_enclosed_expression(a);
4871 			expr += " * ";
4872 			expr += to_extract_component_expression(b, col);
4873 			if (col + 1 < type.columns)
4874 				expr += ", ";
4875 		}
4876 		expr += ")";
4877 		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
4878 		inherit_expression_dependencies(id, a);
4879 		inherit_expression_dependencies(id, b);
4880 		break;
4881 	}
4882 
4883 	case OpFMod:
4884 	{
4885 		if (!requires_op_fmod)
4886 		{
4887 			requires_op_fmod = true;
4888 			force_recompile();
4889 		}
4890 		CompilerGLSL::emit_instruction(instruction);
4891 		break;
4892 	}
4893 
4894 	case OpFRem:
4895 		emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], "fmod");
4896 		break;
4897 
4898 	case OpImage:
4899 	{
4900 		uint32_t result_type = ops[0];
4901 		uint32_t id = ops[1];
4902 		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
4903 
4904 		if (combined)
4905 		{
4906 			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
4907 			auto *var = maybe_get_backing_variable(combined->image);
4908 			if (var)
4909 				e.loaded_from = var->self;
4910 		}
4911 		else
4912 		{
4913 			auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true);
4914 			auto *var = maybe_get_backing_variable(ops[2]);
4915 			if (var)
4916 				e.loaded_from = var->self;
4917 		}
4918 		break;
4919 	}
4920 
4921 	case OpDPdx:
4922 		HLSL_UFOP(ddx);
4923 		register_control_dependent_expression(ops[1]);
4924 		break;
4925 
4926 	case OpDPdy:
4927 		HLSL_UFOP(ddy);
4928 		register_control_dependent_expression(ops[1]);
4929 		break;
4930 
4931 	case OpDPdxFine:
4932 		HLSL_UFOP(ddx_fine);
4933 		register_control_dependent_expression(ops[1]);
4934 		break;
4935 
4936 	case OpDPdyFine:
4937 		HLSL_UFOP(ddy_fine);
4938 		register_control_dependent_expression(ops[1]);
4939 		break;
4940 
4941 	case OpDPdxCoarse:
4942 		HLSL_UFOP(ddx_coarse);
4943 		register_control_dependent_expression(ops[1]);
4944 		break;
4945 
4946 	case OpDPdyCoarse:
4947 		HLSL_UFOP(ddy_coarse);
4948 		register_control_dependent_expression(ops[1]);
4949 		break;
4950 
4951 	case OpFwidth:
4952 	case OpFwidthCoarse:
4953 	case OpFwidthFine:
4954 		HLSL_UFOP(fwidth);
4955 		register_control_dependent_expression(ops[1]);
4956 		break;
4957 
4958 	case OpLogicalNot:
4959 	{
4960 		auto result_type = ops[0];
4961 		auto id = ops[1];
4962 		auto &type = get<SPIRType>(result_type);
4963 
4964 		if (type.vecsize > 1)
4965 			emit_unrolled_unary_op(result_type, id, ops[2], "!");
4966 		else
4967 			HLSL_UOP(!);
4968 		break;
4969 	}
4970 
4971 	case OpIEqual:
4972 	{
4973 		auto result_type = ops[0];
4974 		auto id = ops[1];
4975 
4976 		if (expression_type(ops[2]).vecsize > 1)
4977 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4978 		else
4979 			HLSL_BOP_CAST(==, int_type);
4980 		break;
4981 	}
4982 
4983 	case OpLogicalEqual:
4984 	case OpFOrdEqual:
4985 	case OpFUnordEqual:
4986 	{
4987 		// HLSL != operator is unordered.
4988 		// https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
4989 		// isnan() is apparently implemented as x != x as well.
4990 		// We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
4991 		// HACK: FUnordEqual will be implemented as FOrdEqual.
4992 
4993 		auto result_type = ops[0];
4994 		auto id = ops[1];
4995 
4996 		if (expression_type(ops[2]).vecsize > 1)
4997 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4998 		else
4999 			HLSL_BOP(==);
5000 		break;
5001 	}
5002 
5003 	case OpINotEqual:
5004 	{
5005 		auto result_type = ops[0];
5006 		auto id = ops[1];
5007 
5008 		if (expression_type(ops[2]).vecsize > 1)
5009 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
5010 		else
5011 			HLSL_BOP_CAST(!=, int_type);
5012 		break;
5013 	}
5014 
5015 	case OpLogicalNotEqual:
5016 	case OpFOrdNotEqual:
5017 	case OpFUnordNotEqual:
5018 	{
5019 		// HLSL != operator is unordered.
5020 		// https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5021 		// isnan() is apparently implemented as x != x as well.
5022 
5023 		// FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
5024 		// We would need to do something like not(UnordEqual), but that cannot be expressed either.
5025 		// Adding a lot of NaN checks would be a breaking change from perspective of performance.
5026 		// SPIR-V will generally use isnan() checks when this even matters.
5027 		// HACK: FOrdNotEqual will be implemented as FUnordEqual.
5028 
5029 		auto result_type = ops[0];
5030 		auto id = ops[1];
5031 
5032 		if (expression_type(ops[2]).vecsize > 1)
5033 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
5034 		else
5035 			HLSL_BOP(!=);
5036 		break;
5037 	}
5038 
5039 	case OpUGreaterThan:
5040 	case OpSGreaterThan:
5041 	{
5042 		auto result_type = ops[0];
5043 		auto id = ops[1];
5044 		auto type = opcode == OpUGreaterThan ? uint_type : int_type;
5045 
5046 		if (expression_type(ops[2]).vecsize > 1)
5047 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, type);
5048 		else
5049 			HLSL_BOP_CAST(>, type);
5050 		break;
5051 	}
5052 
5053 	case OpFOrdGreaterThan:
5054 	{
5055 		auto result_type = ops[0];
5056 		auto id = ops[1];
5057 
5058 		if (expression_type(ops[2]).vecsize > 1)
5059 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, SPIRType::Unknown);
5060 		else
5061 			HLSL_BOP(>);
5062 		break;
5063 	}
5064 
5065 	case OpFUnordGreaterThan:
5066 	{
5067 		auto result_type = ops[0];
5068 		auto id = ops[1];
5069 
5070 		if (expression_type(ops[2]).vecsize > 1)
5071 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", true, SPIRType::Unknown);
5072 		else
5073 			CompilerGLSL::emit_instruction(instruction);
5074 		break;
5075 	}
5076 
5077 	case OpUGreaterThanEqual:
5078 	case OpSGreaterThanEqual:
5079 	{
5080 		auto result_type = ops[0];
5081 		auto id = ops[1];
5082 
5083 		auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5084 		if (expression_type(ops[2]).vecsize > 1)
5085 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, type);
5086 		else
5087 			HLSL_BOP_CAST(>=, type);
5088 		break;
5089 	}
5090 
5091 	case OpFOrdGreaterThanEqual:
5092 	{
5093 		auto result_type = ops[0];
5094 		auto id = ops[1];
5095 
5096 		if (expression_type(ops[2]).vecsize > 1)
5097 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, SPIRType::Unknown);
5098 		else
5099 			HLSL_BOP(>=);
5100 		break;
5101 	}
5102 
5103 	case OpFUnordGreaterThanEqual:
5104 	{
5105 		auto result_type = ops[0];
5106 		auto id = ops[1];
5107 
5108 		if (expression_type(ops[2]).vecsize > 1)
5109 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", true, SPIRType::Unknown);
5110 		else
5111 			CompilerGLSL::emit_instruction(instruction);
5112 		break;
5113 	}
5114 
5115 	case OpULessThan:
5116 	case OpSLessThan:
5117 	{
5118 		auto result_type = ops[0];
5119 		auto id = ops[1];
5120 
5121 		auto type = opcode == OpULessThan ? uint_type : int_type;
5122 		if (expression_type(ops[2]).vecsize > 1)
5123 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, type);
5124 		else
5125 			HLSL_BOP_CAST(<, type);
5126 		break;
5127 	}
5128 
5129 	case OpFOrdLessThan:
5130 	{
5131 		auto result_type = ops[0];
5132 		auto id = ops[1];
5133 
5134 		if (expression_type(ops[2]).vecsize > 1)
5135 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, SPIRType::Unknown);
5136 		else
5137 			HLSL_BOP(<);
5138 		break;
5139 	}
5140 
5141 	case OpFUnordLessThan:
5142 	{
5143 		auto result_type = ops[0];
5144 		auto id = ops[1];
5145 
5146 		if (expression_type(ops[2]).vecsize > 1)
5147 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", true, SPIRType::Unknown);
5148 		else
5149 			CompilerGLSL::emit_instruction(instruction);
5150 		break;
5151 	}
5152 
5153 	case OpULessThanEqual:
5154 	case OpSLessThanEqual:
5155 	{
5156 		auto result_type = ops[0];
5157 		auto id = ops[1];
5158 
5159 		auto type = opcode == OpULessThanEqual ? uint_type : int_type;
5160 		if (expression_type(ops[2]).vecsize > 1)
5161 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, type);
5162 		else
5163 			HLSL_BOP_CAST(<=, type);
5164 		break;
5165 	}
5166 
5167 	case OpFOrdLessThanEqual:
5168 	{
5169 		auto result_type = ops[0];
5170 		auto id = ops[1];
5171 
5172 		if (expression_type(ops[2]).vecsize > 1)
5173 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, SPIRType::Unknown);
5174 		else
5175 			HLSL_BOP(<=);
5176 		break;
5177 	}
5178 
5179 	case OpFUnordLessThanEqual:
5180 	{
5181 		auto result_type = ops[0];
5182 		auto id = ops[1];
5183 
5184 		if (expression_type(ops[2]).vecsize > 1)
5185 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", true, SPIRType::Unknown);
5186 		else
5187 			CompilerGLSL::emit_instruction(instruction);
5188 		break;
5189 	}
5190 
5191 	case OpImageQueryLod:
5192 		emit_texture_op(instruction, false);
5193 		break;
5194 
5195 	case OpImageQuerySizeLod:
5196 	{
5197 		auto result_type = ops[0];
5198 		auto id = ops[1];
5199 
5200 		require_texture_query_variant(ops[2]);
5201 		auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5202 		statement("uint ", dummy_samples_levels, ";");
5203 
5204 		auto expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", ",
5205 		                 bitcast_expression(SPIRType::UInt, ops[3]), ", ", dummy_samples_levels, ")");
5206 
5207 		auto &restype = get<SPIRType>(ops[0]);
5208 		expr = bitcast_expression(restype, SPIRType::UInt, expr);
5209 		emit_op(result_type, id, expr, true);
5210 		break;
5211 	}
5212 
5213 	case OpImageQuerySize:
5214 	{
5215 		auto result_type = ops[0];
5216 		auto id = ops[1];
5217 
5218 		require_texture_query_variant(ops[2]);
5219 		bool uav = expression_type(ops[2]).image.sampled == 2;
5220 
5221 		if (const auto *var = maybe_get_backing_variable(ops[2]))
5222 			if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5223 				uav = false;
5224 
5225 		auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5226 		statement("uint ", dummy_samples_levels, ";");
5227 
5228 		string expr;
5229 		if (uav)
5230 			expr = join("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", dummy_samples_levels, ")");
5231 		else
5232 			expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", dummy_samples_levels, ")");
5233 
5234 		auto &restype = get<SPIRType>(ops[0]);
5235 		expr = bitcast_expression(restype, SPIRType::UInt, expr);
5236 		emit_op(result_type, id, expr, true);
5237 		break;
5238 	}
5239 
5240 	case OpImageQuerySamples:
5241 	case OpImageQueryLevels:
5242 	{
5243 		auto result_type = ops[0];
5244 		auto id = ops[1];
5245 
5246 		require_texture_query_variant(ops[2]);
5247 		bool uav = expression_type(ops[2]).image.sampled == 2;
5248 		if (opcode == OpImageQueryLevels && uav)
5249 			SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
5250 
5251 		if (const auto *var = maybe_get_backing_variable(ops[2]))
5252 			if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5253 				uav = false;
5254 
5255 		// Keep it simple and do not emit special variants to make this look nicer ...
5256 		// This stuff is barely, if ever, used.
5257 		forced_temporaries.insert(id);
5258 		auto &type = get<SPIRType>(result_type);
5259 		statement(variable_decl(type, to_name(id)), ";");
5260 
5261 		if (uav)
5262 			statement("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", to_name(id), ");");
5263 		else
5264 			statement("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", to_name(id), ");");
5265 
5266 		auto &restype = get<SPIRType>(ops[0]);
5267 		auto expr = bitcast_expression(restype, SPIRType::UInt, to_name(id));
5268 		set<SPIRExpression>(id, expr, result_type, true);
5269 		break;
5270 	}
5271 
5272 	case OpImageRead:
5273 	{
5274 		uint32_t result_type = ops[0];
5275 		uint32_t id = ops[1];
5276 		auto *var = maybe_get_backing_variable(ops[2]);
5277 		auto &type = expression_type(ops[2]);
5278 		bool subpass_data = type.image.dim == DimSubpassData;
5279 		bool pure = false;
5280 
5281 		string imgexpr;
5282 
5283 		if (subpass_data)
5284 		{
5285 			if (hlsl_options.shader_model < 40)
5286 				SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
5287 
5288 			// Similar to GLSL, implement subpass loads using texelFetch.
5289 			if (type.image.ms)
5290 			{
5291 				uint32_t operands = ops[4];
5292 				if (operands != ImageOperandsSampleMask || instruction.length != 6)
5293 					SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
5294 				uint32_t sample = ops[5];
5295 				imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int2(gl_FragCoord.xy), ", to_expression(sample), ")");
5296 			}
5297 			else
5298 				imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int3(int2(gl_FragCoord.xy), 0))");
5299 
5300 			pure = true;
5301 		}
5302 		else
5303 		{
5304 			imgexpr = join(to_non_uniform_aware_expression(ops[2]), "[", to_expression(ops[3]), "]");
5305 			// The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4",
5306 			// except that the underlying type changes how the data is interpreted.
5307 
5308 			bool force_srv =
5309 			    hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(var->self, DecorationNonWritable);
5310 			pure = force_srv;
5311 
5312 			if (var && !subpass_data && !force_srv)
5313 				imgexpr = remap_swizzle(get<SPIRType>(result_type),
5314 				                        image_format_to_components(get<SPIRType>(var->basetype).image.format), imgexpr);
5315 		}
5316 
5317 		if (var && var->forwardable)
5318 		{
5319 			bool forward = forced_temporaries.find(id) == end(forced_temporaries);
5320 			auto &e = emit_op(result_type, id, imgexpr, forward);
5321 
5322 			if (!pure)
5323 			{
5324 				e.loaded_from = var->self;
5325 				if (forward)
5326 					var->dependees.push_back(id);
5327 			}
5328 		}
5329 		else
5330 			emit_op(result_type, id, imgexpr, false);
5331 
5332 		inherit_expression_dependencies(id, ops[2]);
5333 		if (type.image.ms)
5334 			inherit_expression_dependencies(id, ops[5]);
5335 		break;
5336 	}
5337 
5338 	case OpImageWrite:
5339 	{
5340 		auto *var = maybe_get_backing_variable(ops[0]);
5341 
5342 		// The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4",
5343 		// except that the underlying type changes how the data is interpreted.
5344 		auto value_expr = to_expression(ops[2]);
5345 		if (var)
5346 		{
5347 			auto &type = get<SPIRType>(var->basetype);
5348 			auto narrowed_type = get<SPIRType>(type.image.type);
5349 			narrowed_type.vecsize = image_format_to_components(type.image.format);
5350 			value_expr = remap_swizzle(narrowed_type, expression_type(ops[2]).vecsize, value_expr);
5351 		}
5352 
5353 		statement(to_non_uniform_aware_expression(ops[0]), "[", to_expression(ops[1]), "] = ", value_expr, ";");
5354 		if (var && variable_storage_is_aliased(*var))
5355 			flush_all_aliased_variables();
5356 		break;
5357 	}
5358 
5359 	case OpImageTexelPointer:
5360 	{
5361 		uint32_t result_type = ops[0];
5362 		uint32_t id = ops[1];
5363 
5364 		auto expr = to_expression(ops[2]);
5365 		expr += join("[", to_expression(ops[3]), "]");
5366 		auto &e = set<SPIRExpression>(id, expr, result_type, true);
5367 
5368 		// When using the pointer, we need to know which variable it is actually loaded from.
5369 		auto *var = maybe_get_backing_variable(ops[2]);
5370 		e.loaded_from = var ? var->self : ID(0);
5371 		inherit_expression_dependencies(id, ops[3]);
5372 		break;
5373 	}
5374 
5375 	case OpAtomicCompareExchange:
5376 	case OpAtomicExchange:
5377 	case OpAtomicISub:
5378 	case OpAtomicSMin:
5379 	case OpAtomicUMin:
5380 	case OpAtomicSMax:
5381 	case OpAtomicUMax:
5382 	case OpAtomicAnd:
5383 	case OpAtomicOr:
5384 	case OpAtomicXor:
5385 	case OpAtomicIAdd:
5386 	case OpAtomicIIncrement:
5387 	case OpAtomicIDecrement:
5388 	case OpAtomicLoad:
5389 	case OpAtomicStore:
5390 	{
5391 		emit_atomic(ops, instruction.length, opcode);
5392 		break;
5393 	}
5394 
5395 	case OpControlBarrier:
5396 	case OpMemoryBarrier:
5397 	{
5398 		uint32_t memory;
5399 		uint32_t semantics;
5400 
5401 		if (opcode == OpMemoryBarrier)
5402 		{
5403 			memory = evaluate_constant_u32(ops[0]);
5404 			semantics = evaluate_constant_u32(ops[1]);
5405 		}
5406 		else
5407 		{
5408 			memory = evaluate_constant_u32(ops[1]);
5409 			semantics = evaluate_constant_u32(ops[2]);
5410 		}
5411 
5412 		if (memory == ScopeSubgroup)
5413 		{
5414 			// No Wave-barriers in HLSL.
5415 			break;
5416 		}
5417 
5418 		// We only care about these flags, acquire/release and friends are not relevant to GLSL.
5419 		semantics = mask_relevant_memory_semantics(semantics);
5420 
5421 		if (opcode == OpMemoryBarrier)
5422 		{
5423 			// If we are a memory barrier, and the next instruction is a control barrier, check if that memory barrier
5424 			// does what we need, so we avoid redundant barriers.
5425 			const Instruction *next = get_next_instruction_in_block(instruction);
5426 			if (next && next->op == OpControlBarrier)
5427 			{
5428 				auto *next_ops = stream(*next);
5429 				uint32_t next_memory = evaluate_constant_u32(next_ops[1]);
5430 				uint32_t next_semantics = evaluate_constant_u32(next_ops[2]);
5431 				next_semantics = mask_relevant_memory_semantics(next_semantics);
5432 
5433 				// There is no "just execution barrier" in HLSL.
5434 				// If there are no memory semantics for next instruction, we will imply group shared memory is synced.
5435 				if (next_semantics == 0)
5436 					next_semantics = MemorySemanticsWorkgroupMemoryMask;
5437 
5438 				bool memory_scope_covered = false;
5439 				if (next_memory == memory)
5440 					memory_scope_covered = true;
5441 				else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
5442 				{
5443 					// If we only care about workgroup memory, either Device or Workgroup scope is fine,
5444 					// scope does not have to match.
5445 					if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
5446 					    (memory == ScopeDevice || memory == ScopeWorkgroup))
5447 					{
5448 						memory_scope_covered = true;
5449 					}
5450 				}
5451 				else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
5452 				{
5453 					// The control barrier has device scope, but the memory barrier just has workgroup scope.
5454 					memory_scope_covered = true;
5455 				}
5456 
5457 				// If we have the same memory scope, and all memory types are covered, we're good.
5458 				if (memory_scope_covered && (semantics & next_semantics) == semantics)
5459 					break;
5460 			}
5461 		}
5462 
5463 		// We are synchronizing some memory or syncing execution,
5464 		// so we cannot forward any loads beyond the memory barrier.
5465 		if (semantics || opcode == OpControlBarrier)
5466 		{
5467 			assert(current_emitting_block);
5468 			flush_control_dependent_expressions(current_emitting_block->self);
5469 			flush_all_active_variables();
5470 		}
5471 
5472 		if (opcode == OpControlBarrier)
5473 		{
5474 			// We cannot emit just execution barrier, for no memory semantics pick the cheapest option.
5475 			if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
5476 				statement("GroupMemoryBarrierWithGroupSync();");
5477 			else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5478 				statement("DeviceMemoryBarrierWithGroupSync();");
5479 			else
5480 				statement("AllMemoryBarrierWithGroupSync();");
5481 		}
5482 		else
5483 		{
5484 			if (semantics == MemorySemanticsWorkgroupMemoryMask)
5485 				statement("GroupMemoryBarrier();");
5486 			else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5487 				statement("DeviceMemoryBarrier();");
5488 			else
5489 				statement("AllMemoryBarrier();");
5490 		}
5491 		break;
5492 	}
5493 
5494 	case OpBitFieldInsert:
5495 	{
5496 		if (!requires_bitfield_insert)
5497 		{
5498 			requires_bitfield_insert = true;
5499 			force_recompile();
5500 		}
5501 
5502 		auto expr = join("spvBitfieldInsert(", to_expression(ops[2]), ", ", to_expression(ops[3]), ", ",
5503 		                 to_expression(ops[4]), ", ", to_expression(ops[5]), ")");
5504 
5505 		bool forward =
5506 		    should_forward(ops[2]) && should_forward(ops[3]) && should_forward(ops[4]) && should_forward(ops[5]);
5507 
5508 		auto &restype = get<SPIRType>(ops[0]);
5509 		expr = bitcast_expression(restype, SPIRType::UInt, expr);
5510 		emit_op(ops[0], ops[1], expr, forward);
5511 		break;
5512 	}
5513 
5514 	case OpBitFieldSExtract:
5515 	case OpBitFieldUExtract:
5516 	{
5517 		if (!requires_bitfield_extract)
5518 		{
5519 			requires_bitfield_extract = true;
5520 			force_recompile();
5521 		}
5522 
5523 		if (opcode == OpBitFieldSExtract)
5524 			HLSL_TFOP(spvBitfieldSExtract);
5525 		else
5526 			HLSL_TFOP(spvBitfieldUExtract);
5527 		break;
5528 	}
5529 
5530 	case OpBitCount:
5531 	{
5532 		auto basetype = expression_type(ops[2]).basetype;
5533 		emit_unary_func_op_cast(ops[0], ops[1], ops[2], "countbits", basetype, basetype);
5534 		break;
5535 	}
5536 
5537 	case OpBitReverse:
5538 		HLSL_UFOP(reversebits);
5539 		break;
5540 
5541 	case OpArrayLength:
5542 	{
5543 		auto *var = maybe_get_backing_variable(ops[2]);
5544 		if (!var)
5545 			SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
5546 
5547 		auto &type = get<SPIRType>(var->basetype);
5548 		if (!has_decoration(type.self, DecorationBlock) && !has_decoration(type.self, DecorationBufferBlock))
5549 			SPIRV_CROSS_THROW("Array length expression must point to a block type.");
5550 
5551 		// This must be 32-bit uint, so we're good to go.
5552 		emit_uninitialized_temporary_expression(ops[0], ops[1]);
5553 		statement(to_non_uniform_aware_expression(ops[2]), ".GetDimensions(", to_expression(ops[1]), ");");
5554 		uint32_t offset = type_struct_member_offset(type, ops[3]);
5555 		uint32_t stride = type_struct_member_array_stride(type, ops[3]);
5556 		statement(to_expression(ops[1]), " = (", to_expression(ops[1]), " - ", offset, ") / ", stride, ";");
5557 		break;
5558 	}
5559 
5560 	case OpIsHelperInvocationEXT:
5561 		SPIRV_CROSS_THROW("helperInvocationEXT() is not supported in HLSL.");
5562 
5563 	case OpBeginInvocationInterlockEXT:
5564 	case OpEndInvocationInterlockEXT:
5565 		if (hlsl_options.shader_model < 51)
5566 			SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
5567 		break; // Nothing to do in the body
5568 
5569 	default:
5570 		CompilerGLSL::emit_instruction(instruction);
5571 		break;
5572 	}
5573 }
5574 
require_texture_query_variant(uint32_t var_id)5575 void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
5576 {
5577 	if (const auto *var = maybe_get_backing_variable(var_id))
5578 		var_id = var->self;
5579 
5580 	auto &type = expression_type(var_id);
5581 	bool uav = type.image.sampled == 2;
5582 	if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
5583 		uav = false;
5584 
5585 	uint32_t bit = 0;
5586 	switch (type.image.dim)
5587 	{
5588 	case Dim1D:
5589 		bit = type.image.arrayed ? Query1DArray : Query1D;
5590 		break;
5591 
5592 	case Dim2D:
5593 		if (type.image.ms)
5594 			bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
5595 		else
5596 			bit = type.image.arrayed ? Query2DArray : Query2D;
5597 		break;
5598 
5599 	case Dim3D:
5600 		bit = Query3D;
5601 		break;
5602 
5603 	case DimCube:
5604 		bit = type.image.arrayed ? QueryCubeArray : QueryCube;
5605 		break;
5606 
5607 	case DimBuffer:
5608 		bit = QueryBuffer;
5609 		break;
5610 
5611 	default:
5612 		SPIRV_CROSS_THROW("Unsupported query type.");
5613 	}
5614 
5615 	switch (get<SPIRType>(type.image.type).basetype)
5616 	{
5617 	case SPIRType::Float:
5618 		bit += QueryTypeFloat;
5619 		break;
5620 
5621 	case SPIRType::Int:
5622 		bit += QueryTypeInt;
5623 		break;
5624 
5625 	case SPIRType::UInt:
5626 		bit += QueryTypeUInt;
5627 		break;
5628 
5629 	default:
5630 		SPIRV_CROSS_THROW("Unsupported query type.");
5631 	}
5632 
5633 	auto norm_state = image_format_to_normalized_state(type.image.format);
5634 	auto &variant = uav ? required_texture_size_variants
5635 	                          .uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
5636 	                      required_texture_size_variants.srv;
5637 
5638 	uint64_t mask = 1ull << bit;
5639 	if ((variant & mask) == 0)
5640 	{
5641 		force_recompile();
5642 		variant |= mask;
5643 	}
5644 }
5645 
set_root_constant_layouts(std::vector<RootConstants> layout)5646 void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
5647 {
5648 	root_constants_layout = move(layout);
5649 }
5650 
add_vertex_attribute_remap(const HLSLVertexAttributeRemap & vertex_attributes)5651 void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
5652 {
5653 	remap_vertex_attributes.push_back(vertex_attributes);
5654 }
5655 
remap_num_workgroups_builtin()5656 VariableID CompilerHLSL::remap_num_workgroups_builtin()
5657 {
5658 	update_active_builtins();
5659 
5660 	if (!active_input_builtins.get(BuiltInNumWorkgroups))
5661 		return 0;
5662 
5663 	// Create a new, fake UBO.
5664 	uint32_t offset = ir.increase_bound_by(4);
5665 
5666 	uint32_t uint_type_id = offset;
5667 	uint32_t block_type_id = offset + 1;
5668 	uint32_t block_pointer_type_id = offset + 2;
5669 	uint32_t variable_id = offset + 3;
5670 
5671 	SPIRType uint_type;
5672 	uint_type.basetype = SPIRType::UInt;
5673 	uint_type.width = 32;
5674 	uint_type.vecsize = 3;
5675 	uint_type.columns = 1;
5676 	set<SPIRType>(uint_type_id, uint_type);
5677 
5678 	SPIRType block_type;
5679 	block_type.basetype = SPIRType::Struct;
5680 	block_type.member_types.push_back(uint_type_id);
5681 	set<SPIRType>(block_type_id, block_type);
5682 	set_decoration(block_type_id, DecorationBlock);
5683 	set_member_name(block_type_id, 0, "count");
5684 	set_member_decoration(block_type_id, 0, DecorationOffset, 0);
5685 
5686 	SPIRType block_pointer_type = block_type;
5687 	block_pointer_type.pointer = true;
5688 	block_pointer_type.storage = StorageClassUniform;
5689 	block_pointer_type.parent_type = block_type_id;
5690 	auto &ptr_type = set<SPIRType>(block_pointer_type_id, block_pointer_type);
5691 
5692 	// Preserve self.
5693 	ptr_type.self = block_type_id;
5694 
5695 	set<SPIRVariable>(variable_id, block_pointer_type_id, StorageClassUniform);
5696 	ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";
5697 
5698 	num_workgroups_builtin = variable_id;
5699 	return variable_id;
5700 }
5701 
set_resource_binding_flags(HLSLBindingFlags flags)5702 void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
5703 {
5704 	resource_binding_flags = flags;
5705 }
5706 
validate_shader_model()5707 void CompilerHLSL::validate_shader_model()
5708 {
5709 	// Check for nonuniform qualifier.
5710 	// Instead of looping over all decorations to find this, just look at capabilities.
5711 	for (auto &cap : ir.declared_capabilities)
5712 	{
5713 		switch (cap)
5714 		{
5715 		case CapabilityShaderNonUniformEXT:
5716 		case CapabilityRuntimeDescriptorArrayEXT:
5717 			if (hlsl_options.shader_model < 51)
5718 				SPIRV_CROSS_THROW(
5719 				    "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
5720 			break;
5721 
5722 		case CapabilityVariablePointers:
5723 		case CapabilityVariablePointersStorageBuffer:
5724 			SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");
5725 
5726 		default:
5727 			break;
5728 		}
5729 	}
5730 
5731 	if (ir.addressing_model != AddressingModelLogical)
5732 		SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
5733 
5734 	if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
5735 		SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
5736 }
5737 
compile()5738 string CompilerHLSL::compile()
5739 {
5740 	ir.fixup_reserved_names();
5741 
5742 	// Do not deal with ES-isms like precision, older extensions and such.
5743 	options.es = false;
5744 	options.version = 450;
5745 	options.vulkan_semantics = true;
5746 	backend.float_literal_suffix = true;
5747 	backend.double_literal_suffix = false;
5748 	backend.long_long_literal_suffix = true;
5749 	backend.uint32_t_literal_suffix = true;
5750 	backend.int16_t_literal_suffix = "";
5751 	backend.uint16_t_literal_suffix = "u";
5752 	backend.basic_int_type = "int";
5753 	backend.basic_uint_type = "uint";
5754 	backend.demote_literal = "discard";
5755 	backend.boolean_mix_function = "";
5756 	backend.swizzle_is_function = false;
5757 	backend.shared_is_implied = true;
5758 	backend.unsized_array_supported = true;
5759 	backend.explicit_struct_type = false;
5760 	backend.use_initializer_list = true;
5761 	backend.use_constructor_splatting = false;
5762 	backend.can_swizzle_scalar = true;
5763 	backend.can_declare_struct_inline = false;
5764 	backend.can_declare_arrays_inline = false;
5765 	backend.can_return_array = false;
5766 	backend.nonuniform_qualifier = "NonUniformResourceIndex";
5767 	backend.support_case_fallthrough = false;
5768 
5769 	// SM 4.1 does not support precise for some reason.
5770 	backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;
5771 
5772 	fixup_type_alias();
5773 	reorder_type_alias();
5774 	build_function_control_flow_graphs_and_analyze();
5775 	validate_shader_model();
5776 	update_active_builtins();
5777 	analyze_image_and_sampler_usage();
5778 	analyze_interlocked_resource_usage();
5779 
5780 	// Subpass input needs SV_Position.
5781 	if (need_subpass_input)
5782 		active_input_builtins.set(BuiltInFragCoord);
5783 
5784 	uint32_t pass_count = 0;
5785 	do
5786 	{
5787 		if (pass_count >= 3)
5788 			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
5789 
5790 		reset();
5791 
5792 		// Move constructor for this type is broken on GCC 4.9 ...
5793 		buffer.reset();
5794 
5795 		emit_header();
5796 		emit_resources();
5797 
5798 		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
5799 		emit_hlsl_entry_point();
5800 
5801 		pass_count++;
5802 	} while (is_forcing_recompilation());
5803 
5804 	// Entry point in HLSL is always main() for the time being.
5805 	get_entry_point().name = "main";
5806 
5807 	return buffer.str();
5808 }
5809 
emit_block_hints(const SPIRBlock & block)5810 void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
5811 {
5812 	switch (block.hint)
5813 	{
5814 	case SPIRBlock::HintFlatten:
5815 		statement("[flatten]");
5816 		break;
5817 	case SPIRBlock::HintDontFlatten:
5818 		statement("[branch]");
5819 		break;
5820 	case SPIRBlock::HintUnroll:
5821 		statement("[unroll]");
5822 		break;
5823 	case SPIRBlock::HintDontUnroll:
5824 		statement("[loop]");
5825 		break;
5826 	default:
5827 		break;
5828 	}
5829 }
5830 
get_unique_identifier()5831 string CompilerHLSL::get_unique_identifier()
5832 {
5833 	return join("_", unique_identifier_count++, "ident");
5834 }
5835 
add_hlsl_resource_binding(const HLSLResourceBinding & binding)5836 void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
5837 {
5838 	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
5839 	resource_bindings[tuple] = { binding, false };
5840 }
5841 
is_hlsl_resource_binding_used(ExecutionModel model,uint32_t desc_set,uint32_t binding) const5842 bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
5843 {
5844 	StageSetBinding tuple = { model, desc_set, binding };
5845 	auto itr = resource_bindings.find(tuple);
5846 	return itr != end(resource_bindings) && itr->second.second;
5847 }
5848 
get_bitcast_type(uint32_t result_type,uint32_t op0)5849 CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
5850 {
5851 	auto &rslt_type = get<SPIRType>(result_type);
5852 	auto &expr_type = expression_type(op0);
5853 
5854 	if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
5855 	    expr_type.vecsize == 2)
5856 		return BitcastType::TypePackUint2x32;
5857 	else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
5858 	         expr_type.basetype == SPIRType::BaseType::UInt64)
5859 		return BitcastType::TypeUnpackUint64;
5860 
5861 	return BitcastType::TypeNormal;
5862 }
5863 
is_hlsl_force_storage_buffer_as_uav(ID id) const5864 bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
5865 {
5866 	if (hlsl_options.force_storage_buffer_as_uav)
5867 	{
5868 		return true;
5869 	}
5870 
5871 	const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
5872 	const uint32_t binding = get_decoration(id, spv::DecorationBinding);
5873 
5874 	return (force_uav_buffer_bindings.find({ desc_set, binding }) != force_uav_buffer_bindings.end());
5875 }
5876 
set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set,uint32_t binding)5877 void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
5878 {
5879 	SetBindingPair pair = { desc_set, binding };
5880 	force_uav_buffer_bindings.insert(pair);
5881 }
5882 
builtin_translates_to_nonarray(spv::BuiltIn builtin) const5883 bool CompilerHLSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
5884 {
5885 	return (builtin == BuiltInSampleMask);
5886 }
5887